Browse Source

!84 ge submodule metadef and parser

Merge pull request !84 from taoxiangdong/master
tags/v1.1.0
lujiale Gitee 5 years ago
parent
commit
4a12843827
100 changed files with 7 additions and 30968 deletions
  1. +6
    -0
      .gitmodules
  2. +1
    -0
      metadef
  3. +0
    -79
      metadef/graph/CMakeLists.txt
  4. +0
    -371
      metadef/graph/anchor.cc
  5. +0
    -38
      metadef/graph/attr_value.cc
  6. +0
    -112
      metadef/graph/buffer.cc
  7. +0
    -1314
      metadef/graph/compute_graph.cc
  8. +0
    -147
      metadef/graph/debug/ge_log.h
  9. +0
    -69
      metadef/graph/debug/ge_op_types.h
  10. +0
    -274
      metadef/graph/debug/ge_util.h
  11. +0
    -246
      metadef/graph/debug/graph_debug.cc
  12. +0
    -48
      metadef/graph/debug/graph_debug.h
  13. +0
    -241
      metadef/graph/detail/attributes_holder.cc
  14. +0
    -508
      metadef/graph/format_refiner.cc
  15. +0
    -50
      metadef/graph/format_refiner.h
  16. +0
    -1078
      metadef/graph/ge_attr_define.cc
  17. +0
    -1289
      metadef/graph/ge_attr_value.cc
  18. +0
    -1021
      metadef/graph/ge_tensor.cc
  19. +0
    -384
      metadef/graph/graph.cc
  20. +0
    -294
      metadef/graph/graph.mk
  21. +0
    -112
      metadef/graph/inference_context.cc
  22. +0
    -190
      metadef/graph/model.cc
  23. +0
    -763
      metadef/graph/model_serialize.cc
  24. +0
    -3
      metadef/graph/module.mk
  25. +0
    -877
      metadef/graph/node.cc
  26. +0
    -1370
      metadef/graph/op_desc.cc
  27. +0
    -79
      metadef/graph/op_imp.cc
  28. +0
    -1587
      metadef/graph/operator.cc
  29. +0
    -48
      metadef/graph/operator_factory.cc
  30. +0
    -149
      metadef/graph/operator_factory_impl.cc
  31. +0
    -187
      metadef/graph/opsproto/opsproto_manager.cc
  32. +0
    -104
      metadef/graph/option/ge_context.cc
  33. +0
    -60
      metadef/graph/option/ge_local_context.cc
  34. +0
    -455
      metadef/graph/ref_relation.cc
  35. +0
    -96
      metadef/graph/runtime_inference_context.cc
  36. +0
    -688
      metadef/graph/shape_refiner.cc
  37. +0
    -704
      metadef/graph/tensor.cc
  38. +0
    -102
      metadef/graph/utils/anchor_utils.cc
  39. +0
    -1178
      metadef/graph/utils/ge_ir_utils.cc
  40. +0
    -206
      metadef/graph/utils/ge_ir_utils.h
  41. +0
    -2767
      metadef/graph/utils/graph_utils.cc
  42. +0
    -32
      metadef/graph/utils/mem_utils.h
  43. +0
    -956
      metadef/graph/utils/node_utils.cc
  44. +0
    -778
      metadef/graph/utils/op_desc_utils.cc
  45. +0
    -68
      metadef/graph/utils/string_utils.h
  46. +0
    -401
      metadef/graph/utils/tensor_utils.cc
  47. +0
    -684
      metadef/graph/utils/tuning_utils.cc
  48. +0
    -448
      metadef/graph/utils/type_utils.cc
  49. +0
    -75
      metadef/inc/external/graph/attr_value.h
  50. +0
    -38
      metadef/inc/external/graph/ge_error_codes.h
  51. +0
    -81
      metadef/inc/external/graph/graph.h
  52. +0
    -76
      metadef/inc/external/graph/inference_context.h
  53. +0
    -289
      metadef/inc/external/graph/operator.h
  54. +0
    -68
      metadef/inc/external/graph/operator_factory.h
  55. +0
    -376
      metadef/inc/external/graph/operator_reg.h
  56. +0
    -131
      metadef/inc/external/graph/tensor.h
  57. +0
    -240
      metadef/inc/external/graph/types.h
  58. +0
    -163
      metadef/inc/external/register/register.h
  59. +0
    -39
      metadef/inc/external/register/register_error_codes.h
  60. +0
    -37
      metadef/inc/external/register/register_fmk_types.h
  61. +0
    -59
      metadef/inc/external/register/register_types.h
  62. +0
    -334
      metadef/inc/external/register/scope/scope_fusion_pass_register.h
  63. +0
    -284
      metadef/inc/graph/anchor.h
  64. +0
    -191
      metadef/inc/graph/attr_value_serializable.h
  65. +0
    -82
      metadef/inc/graph/buffer.h
  66. +0
    -308
      metadef/inc/graph/compute_graph.h
  67. +0
    -1122
      metadef/inc/graph/debug/ge_attr_define.h
  68. +0
    -195
      metadef/inc/graph/def_types.h
  69. +0
    -120
      metadef/inc/graph/detail/any_map.h
  70. +0
    -165
      metadef/inc/graph/detail/attributes_holder.h
  71. +0
    -93
      metadef/inc/graph/detail/model_serialize_imp.h
  72. +0
    -343
      metadef/inc/graph/ge_attr_value.h
  73. +0
    -46
      metadef/inc/graph/ge_context.h
  74. +0
    -26
      metadef/inc/graph/ge_global_options.h
  75. +0
    -44
      metadef/inc/graph/ge_local_context.h
  76. +0
    -193
      metadef/inc/graph/ge_tensor.h
  77. +0
    -134
      metadef/inc/graph/graph_util.h
  78. +0
    -94
      metadef/inc/graph/model.h
  79. +0
    -52
      metadef/inc/graph/model_serialize.h
  80. +0
    -213
      metadef/inc/graph/node.h
  81. +0
    -328
      metadef/inc/graph/op_desc.h
  82. +0
    -48
      metadef/inc/graph/op_kernel_bin.h
  83. +0
    -56
      metadef/inc/graph/operator_factory_impl.h
  84. +0
    -46
      metadef/inc/graph/opsproto_manager.h
  85. +0
    -53
      metadef/inc/graph/range_vistor.h
  86. +0
    -79
      metadef/inc/graph/ref_relation.h
  87. +0
    -46
      metadef/inc/graph/runtime_inference_context.h
  88. +0
    -40
      metadef/inc/graph/shape_refiner.h
  89. +0
    -130
      metadef/inc/graph/tuning_utils.h
  90. +0
    -133
      metadef/inc/graph/usr_types.h
  91. +0
    -45
      metadef/inc/graph/utils/anchor_utils.h
  92. +0
    -150
      metadef/inc/graph/utils/attr_utils.h
  93. +0
    -771
      metadef/inc/graph/utils/graph_utils.h
  94. +0
    -170
      metadef/inc/graph/utils/node_utils.h
  95. +0
    -181
      metadef/inc/graph/utils/op_desc_utils.h
  96. +0
    -43
      metadef/inc/graph/utils/tensor_adapter.h
  97. +0
    -77
      metadef/inc/graph/utils/tensor_utils.h
  98. +0
    -53
      metadef/inc/graph/utils/type_utils.h
  99. +0
    -127
      metadef/proto/dump_task.proto
  100. +0
    -26
      metadef/proto/fusion_model.proto

+ 6
- 0
.gitmodules View File

@@ -0,0 +1,6 @@
[submodule "parser"]
path = parser
url = https://gitee.com/ascend/parser.git
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git

+ 1
- 0
metadef

@@ -0,0 +1 @@
Subproject commit d097f7ce4e7a3ec449e13df78e26e8a76ca48cec

+ 0
- 79
metadef/graph/CMakeLists.txt View File

@@ -1,79 +0,0 @@
# Copyright 2019-2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# libgraph.so
# compiling proto files generates some warnings, use no-unused-variable to suppress them
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}")
# add all proto files, generate corresponding .h and .cc files
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"${GE_SOURCE_DIR}/metadef/proto/om.proto"
"${GE_SOURCE_DIR}/metadef/proto/ge_ir.proto"
"${GE_SOURCE_DIR}/metadef/proto/insert_op.proto"
"${GE_SOURCE_DIR}/metadef/proto/task.proto"
"${GE_SOURCE_DIR}/metadef/proto/fwk_adaper.proto"
"${GE_SOURCE_DIR}/metadef/proto/op_mapping_info.proto"
"${GE_SOURCE_DIR}/metadef/proto/dump_task.proto"
)

file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"${onnx_INC}/onnx/onnx.proto"
)

ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST})

# need to remove dependencies on pb files later
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"*.cc"
"utils/*.cc"
"opsproto/*.cc"
"detail/*.cc"
"debug/*.cc"
"option/*.cc"
)

# include directories
include_directories(${CMAKE_CURRENT_LIST_DIR})
include_directories(${GE_SOURCE_DIR})
#include_directories(${GE_SOURCE_DIR}/src)
include_directories(${GE_SOURCE_DIR}/ge)
include_directories(${GE_SOURCE_DIR}/metadef)
include_directories(${GE_SOURCE_DIR}/metadef/graph)
include_directories(${GE_SOURCE_DIR}/inc)
include_directories(${GE_SOURCE_DIR}/inc/framework)
include_directories(${GE_SOURCE_DIR}/inc/external)
include_directories(${GE_SOURCE_DIR}/metadef/inc)
include_directories(${GE_SOURCE_DIR}/metadef/inc/external/graph)
include_directories(${GE_SOURCE_DIR}/metadef/inc/external)
include_directories(${GE_SOURCE_DIR}/metadef/inc/graph)
include_directories(${GE_SOURCE_DIR}/inc/common)
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc)
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops)
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${GE_SOURCE_DIR}/build)

######### libgraph.so #############
add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_ONNX_SRCS})
target_compile_definitions(graph PRIVATE
DAVINCI_CLOUD
Werror)
target_link_libraries(graph PRIVATE
${PROTOBUF_LIBRARY}
${c_sec}
${slog}
${error_manager}
rt
dl)

+ 0
- 371
metadef/graph/anchor.cc View File

@@ -1,371 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/anchor.h"
#include <algorithm>
#include <cstring>
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/node.h"

namespace ge {
Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), idx_(idx) {}

bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; }

size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); }

Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const {
vector<AnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
ret.push_back(anchor.lock());
}
return Anchor::Vistor<AnchorPtr>(shared_from_this(), ret);
}

AnchorPtr Anchor::GetFirstPeerAnchor() const {
if (peer_anchors_.empty()) {
return nullptr;
} else {
return Anchor::DynamicAnchorCast<Anchor>(peer_anchors_.begin()->lock());
}
}

NodePtr Anchor::GetOwnerNode() const { return owner_node_.lock(); }

void Anchor::UnlinkAll() noexcept {
if (!peer_anchors_.empty()) {
do {
auto peer_anchor_ptr = peer_anchors_.begin()->lock();
if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) {
GELOGW("unlink peer_anchor_ptr failed.");
}
} while (!peer_anchors_.empty());
}
}

graphStatus Anchor::Unlink(const AnchorPtr &peer) {
if (peer == nullptr) {
GELOGE(GRAPH_FAILED, "peer anchor is invalid.");
return GRAPH_FAILED;
}
auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr<Anchor> &an) {
auto anchor = an.lock();
return peer->Equal(anchor);
});

GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED);

auto it_peer =
std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr<Anchor> &an) {
auto anchor = an.lock();
return Equal(anchor);
});

GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor");

(void)peer_anchors_.erase(it);
(void)peer->peer_anchors_.erase(it_peer);
return GRAPH_SUCCESS;
}

graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) {
GE_CHK_BOOL_RET_STATUS(old_peer != nullptr, GRAPH_FAILED, "this old peer anchor is nullptr");
GE_CHK_BOOL_RET_STATUS(first_peer != nullptr, GRAPH_FAILED, "this first peer anchor is nullptr");
GE_CHK_BOOL_RET_STATUS(second_peer != nullptr, GRAPH_FAILED, "this second peer anchor is nullptr");
auto this_it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [old_peer](const std::weak_ptr<Anchor> &an) {
auto anchor = an.lock();
return old_peer->Equal(anchor);
});

GE_CHK_BOOL_RET_STATUS(this_it != peer_anchors_.end(), GRAPH_FAILED, "this anchor is not connected to old_peer");

auto old_it = std::find_if(old_peer->peer_anchors_.begin(), old_peer->peer_anchors_.end(),
[this](const std::weak_ptr<Anchor> &an) {
auto anchor = an.lock();
return Equal(anchor);
});

GE_CHK_BOOL_RET_STATUS(old_it != old_peer->peer_anchors_.end(), GRAPH_FAILED,
"old_peer is not connected to this anchor");
*this_it = first_peer;
first_peer->peer_anchors_.push_back(shared_from_this());
*old_it = second_peer;
second_peer->peer_anchors_.push_back(old_peer);
return GRAPH_SUCCESS;
}

bool Anchor::IsLinkedWith(const AnchorPtr &peer) {
auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr<Anchor> &an) {
auto anchor = an.lock();
GE_CHK_BOOL_RET_STATUS(peer != nullptr, false, "this old peer anchor is nullptr");
return peer->Equal(anchor);
});
return (it != peer_anchors_.end());
}

int Anchor::GetIdx() const { return idx_; }

void Anchor::SetIdx(int index) { idx_ = index; }

DataAnchor::DataAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {}

bool DataAnchor::IsTypeOf(TYPE type) const {
if (strcmp(Anchor::TypeOf<DataAnchor>(), type) == 0) {
return true;
}
return Anchor::IsTypeOf(type);
}

InDataAnchor::InDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {}

OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const {
if (peer_anchors_.empty()) {
return nullptr;
} else {
return Anchor::DynamicAnchorCast<OutDataAnchor>(peer_anchors_.begin()->lock());
}
}

graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) {
// InDataAnchor must be only linkfrom once
if (src == nullptr || !peer_anchors_.empty()) {
GELOGE(GRAPH_FAILED, "src anchor is invalid or the peerAnchors is not empty.");
return GRAPH_FAILED;
}
peer_anchors_.push_back(src);
src->peer_anchors_.push_back(shared_from_this());
return GRAPH_SUCCESS;
}

bool InDataAnchor::Equal(AnchorPtr anchor) const {
auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor);
if (in_data_anchor != nullptr) {
if (GetOwnerNode() == in_data_anchor->GetOwnerNode() && GetIdx() == in_data_anchor->GetIdx()) {
return true;
}
}
return false;
}

bool InDataAnchor::IsTypeOf(TYPE type) const {
if (strcmp(Anchor::TypeOf<InDataAnchor>(), type) == 0) {
return true;
}
return DataAnchor::IsTypeOf(type);
}

OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {}

OutDataAnchor::Vistor<InDataAnchorPtr> OutDataAnchor::GetPeerInDataAnchors() const {
vector<InDataAnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor.lock());
if (in_data_anchor != nullptr) {
ret.push_back(in_data_anchor);
}
}
return OutDataAnchor::Vistor<InDataAnchorPtr>(shared_from_this(), ret);
}

uint32_t OutDataAnchor::GetPeerInDataNodesSize() const {
uint32_t out_nums = 0;
for (const auto &anchor : peer_anchors_) {
auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor.lock());
if (in_data_anchor != nullptr && in_data_anchor->GetOwnerNode() != nullptr) {
out_nums++;
}
}
return out_nums;
}

OutDataAnchor::Vistor<InControlAnchorPtr> OutDataAnchor::GetPeerInControlAnchors() const {
vector<InControlAnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(anchor.lock());
if (in_control_anchor != nullptr) {
ret.push_back(in_control_anchor);
}
}
return OutDataAnchor::Vistor<InControlAnchorPtr>(shared_from_this(), ret);
}

graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) {
if (dest == nullptr || !dest->peer_anchors_.empty()) {
GELOGE(GRAPH_FAILED, "dest anchor is invalid or the peerAnchors is not empty.");
return GRAPH_FAILED;
}
peer_anchors_.push_back(dest);
dest->peer_anchors_.push_back(shared_from_this());
return GRAPH_SUCCESS;
}

graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) {
if (dest == nullptr) {
GELOGE(GRAPH_FAILED, "dest anchor is invalid.");
return GRAPH_FAILED;
}
peer_anchors_.push_back(dest);
dest->peer_anchors_.push_back(shared_from_this());
return GRAPH_SUCCESS;
}

graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) {
if (dest == nullptr) {
GELOGE(GRAPH_FAILED, "dest anchor is invalid.");
return GRAPH_FAILED;
}
peer_anchors_.push_back(dest);
dest->peer_anchors_.push_back(shared_from_this());
return GRAPH_SUCCESS;
}

bool OutDataAnchor::Equal(AnchorPtr anchor) const {
CHECK_FALSE_EXEC(anchor != nullptr, return false);
auto out_data_anchor = Anchor::DynamicAnchorCast<OutDataAnchor>(anchor);
if (out_data_anchor != nullptr) {
if (GetOwnerNode() == out_data_anchor->GetOwnerNode() && GetIdx() == out_data_anchor->GetIdx()) {
return true;
}
}
return false;
}

bool OutDataAnchor::IsTypeOf(TYPE type) const {
if (strcmp(Anchor::TypeOf<OutDataAnchor>(), type) == 0) {
return true;
}
return DataAnchor::IsTypeOf(type);
}

ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {}

ControlAnchor::ControlAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {}

bool ControlAnchor::IsTypeOf(TYPE type) const {
if (strcmp(Anchor::TypeOf<ControlAnchor>(), type) == 0) {
return true;
}
return Anchor::IsTypeOf(type);
}

InControlAnchor::InControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {}

InControlAnchor::InControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {}

InControlAnchor::Vistor<OutControlAnchorPtr> InControlAnchor::GetPeerOutControlAnchors() const {
vector<OutControlAnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
auto out_control_anchor = Anchor::DynamicAnchorCast<OutControlAnchor>(anchor.lock());
if (out_control_anchor != nullptr) {
ret.push_back(out_control_anchor);
}
}
return InControlAnchor::Vistor<OutControlAnchorPtr>(shared_from_this(), ret);
}

InControlAnchor::Vistor<OutDataAnchorPtr> InControlAnchor::GetPeerOutDataAnchors() const {
vector<OutDataAnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
auto out_data_anchor = Anchor::DynamicAnchorCast<OutDataAnchor>(anchor.lock());
if (out_data_anchor != nullptr) {
ret.push_back(out_data_anchor);
}
}
return InControlAnchor::Vistor<OutDataAnchorPtr>(shared_from_this(), ret);
}

graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) {
if (src == nullptr) {
GELOGE(GRAPH_FAILED, "src anchor is invalid.");
return GRAPH_FAILED;
}
peer_anchors_.push_back(src);
src->peer_anchors_.push_back(shared_from_this());
return GRAPH_SUCCESS;
}

bool InControlAnchor::Equal(AnchorPtr anchor) const {
CHECK_FALSE_EXEC(anchor != nullptr, return false);
auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(anchor);
if (in_control_anchor != nullptr) {
if (GetOwnerNode() == in_control_anchor->GetOwnerNode()) {
return true;
}
}
return false;
}

bool InControlAnchor::IsTypeOf(TYPE type) const {
if (strcmp(Anchor::TypeOf<InControlAnchor>(), type) == 0) {
return true;
}
return ControlAnchor::IsTypeOf(type);
}

OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {}

OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {}

OutControlAnchor::Vistor<InControlAnchorPtr> OutControlAnchor::GetPeerInControlAnchors() const {
vector<InControlAnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(anchor.lock());
if (in_control_anchor != nullptr) {
ret.push_back(in_control_anchor);
}
}
return OutControlAnchor::Vistor<InControlAnchorPtr>(shared_from_this(), ret);
}

OutControlAnchor::Vistor<InDataAnchorPtr> OutControlAnchor::GetPeerInDataAnchors() const {
vector<InDataAnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {
auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor.lock());
if (in_data_anchor != nullptr) {
ret.push_back(in_data_anchor);
}
}
return OutControlAnchor::Vistor<InDataAnchorPtr>(shared_from_this(), ret);
}

graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) {
if (dest == nullptr) {
GELOGE(GRAPH_FAILED, "dest anchor is invalid.");
return GRAPH_FAILED;
}
peer_anchors_.push_back(dest);
dest->peer_anchors_.push_back(shared_from_this());
return GRAPH_SUCCESS;
}

bool OutControlAnchor::Equal(AnchorPtr anchor) const {
auto out_control_anchor = Anchor::DynamicAnchorCast<OutControlAnchor>(anchor);
if (out_control_anchor != nullptr) {
if (GetOwnerNode() == out_control_anchor->GetOwnerNode()) {
return true;
}
}
return false;
}

bool OutControlAnchor::IsTypeOf(TYPE type) const {
if (strcmp(Anchor::TypeOf<OutControlAnchor>(), type) == 0) {
return true;
}
return ControlAnchor::IsTypeOf(type);
}
} // namespace ge

+ 0
- 38
metadef/graph/attr_value.cc View File

@@ -1,38 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "external/graph/attr_value.h"
#include "debug/ge_log.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_attr_value.h"

namespace ge {
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { impl = ComGraphMakeShared<AttrValueImpl>(); }

#define ATTR_VALUE_SET_GET_IMP(type) \
graphStatus AttrValue::GetValue(type &val) const { \
if (impl != nullptr) { \
GELOGW("GetValue failed."); \
return impl->geAttrValue_.GetValue<type>(val); \
} \
return GRAPH_FAILED; \
}

ATTR_VALUE_SET_GET_IMP(AttrValue::STR)
ATTR_VALUE_SET_GET_IMP(AttrValue::INT)
ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT)
} // namespace ge

+ 0
- 112
metadef/graph/buffer.cc View File

@@ -1,112 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/buffer.h"
#include "proto/ge_ir.pb.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
Buffer::Buffer() {
data_.InitDefault();
if (data_.GetProtoMsg()) {
buffer_ = data_.GetProtoMsg()->mutable_bt();
}
}

Buffer::Buffer(const Buffer &other) {
// Share data
data_ = other.data_;
buffer_ = other.buffer_;
}

Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default
auto proto_msg = data_.GetProtoMsg();
if (proto_msg != nullptr) {
try {
proto_msg->set_bt(std::string(buffer_size, default_val));
buffer_ = proto_msg->mutable_bt();
} catch (std::bad_alloc &e) {
GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size);
buffer_ = nullptr;
}
}
}

Buffer Buffer::CopyFrom(const std::uint8_t *data, std::size_t buffer_size) {
Buffer buffer;
auto proto_msg = buffer.data_.GetProtoMsg();
if (proto_msg != nullptr && data != nullptr) {
try {
proto_msg->set_bt(data, buffer_size);
buffer.buffer_ = proto_msg->mutable_bt();
} catch (std::bad_alloc &e) {
GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size);
buffer.buffer_ = nullptr;
}
}
return buffer;
}

Buffer::Buffer(const std::shared_ptr<google::protobuf::Message> &proto_owner, proto::AttrDef *buffer)
: data_(proto_owner, buffer) {
if (data_.GetProtoMsg() != nullptr) {
buffer_ = data_.GetProtoMsg()->mutable_bt();
}
}

Buffer::Buffer(const std::shared_ptr<google::protobuf::Message> &proto_owner, std::string *buffer)
: data_(proto_owner, nullptr) {
buffer_ = buffer;
}

Buffer &Buffer::operator=(const Buffer &other) {
if (&other != this) {
// Share data
data_ = other.data_;
buffer_ = other.buffer_;
}
return *this;
}

const std::uint8_t *Buffer::GetData() const {
if (buffer_ != nullptr) {
return (const std::uint8_t *)buffer_->data();
}
return nullptr;
}

std::uint8_t *Buffer::GetData() {
if (buffer_ != nullptr && !buffer_->empty()) {
// Avoid copy on write
(void)(*buffer_)[0];
return reinterpret_cast<uint8_t *>(const_cast<char *>(buffer_->data()));
}
return nullptr;
}

std::size_t Buffer::GetSize() const {
if (buffer_ != nullptr) {
return buffer_->size();
}
return 0;
}

void Buffer::ClearBuffer() {
if (buffer_ != nullptr) {
buffer_->clear();
}
}
} // namespace ge

+ 0
- 1314
metadef/graph/compute_graph.cc
File diff suppressed because it is too large
View File


+ 0
- 147
metadef/graph/debug/ge_log.h View File

@@ -1,147 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_
#define COMMON_GRAPH_DEBUG_GE_LOG_H_

#include "graph/ge_error_codes.h"
#include "framework/common/debug/ge_log.h"

#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__)

#define GE_LOGI_IF(condition, ...) \
if ((condition)) { \
GELOGI(__VA_ARGS__); \
}

#define GE_LOGW_IF(condition, ...) \
if ((condition)) { \
GELOGW(__VA_ARGS__); \
}

#define GE_LOGE_IF(condition, ...) \
if ((condition)) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
}

#define GE_CHK_STATUS_RET_NOLOG(expr) \
do { \
const ge::graphStatus _status = (expr); \
if (ge::SUCCESS != _status) { \
return _status; \
} \
} while (0)

#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)

#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \
{ \
bool b = (expr); \
if (!b) { \
exec_expr; \
} \
}

#define GE_IF_BOOL_EXEC(expr, exec_expr) \
{ \
if (expr) { \
exec_expr; \
} \
}

#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \
do { \
const ge::graphStatus _status = (expr); \
if (_status) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)

// If expr is true, the log is printed and a custom statement is executed
#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// Only check error log
#define GE_CHK_BOOL_ONLY_LOG(expr, ...) \
do { \
bool b = (expr); \
if (!b) { \
GELOGI(__VA_ARGS__); \
} \
} while (0)

// If expr is not true, do not print the log and return the specified status
#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
return _status; \
} \
} while (0)

// If expr is not true, the log is printed and a custom statement is executed
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not true, the log is printed and a custom statement is executed
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not GRAPH_SUCCESS, print the log and return the same value
#define GE_CHK_STATUS_RET(expr, ...) \
do { \
const ge::graphStatus _status = (expr); \
if (ge::SUCCESS != _status) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)

#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \
try { \
exec_expr0; \
} catch (...) { \
GELOGE(ge::FAILED, "Make shared failed"); \
exec_expr1; \
}

#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_

+ 0
- 69
metadef/graph/debug/ge_op_types.h View File

@@ -1,69 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_
#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_

namespace ge {
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name

GE_REGISTER_OPTYPE(DATA, "Data");
GE_REGISTER_OPTYPE(AIPPDATA, "AippData");
GE_REGISTER_OPTYPE(MATMUL, "MatMul");
GE_REGISTER_OPTYPE(RESHAPE, "Reshape");
GE_REGISTER_OPTYPE(PERMUTE, "Permute");
GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput");
GE_REGISTER_OPTYPE(_WHILE, "_While");
GE_REGISTER_OPTYPE(WHILE, "While");
GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile");
GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze");
GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims");
GE_REGISTER_OPTYPE(SWITCH, "Switch");
GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch");
GE_REGISTER_OPTYPE(SWITCHN, "SwitchN");
GE_REGISTER_OPTYPE(MERGE, "Merge");
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge");
GE_REGISTER_OPTYPE(ENTER, "Enter");
GE_REGISTER_OPTYPE(REFENTER, "RefEnter");
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration");
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration");
GE_REGISTER_OPTYPE(CONSTANT, "Const");
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder");
GE_REGISTER_OPTYPE(END, "End");
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp");
GE_REGISTER_OPTYPE(GETNEXT, "GetNext");
GE_REGISTER_OPTYPE(INITDATA, "InitData");
GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity");
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData");

GE_REGISTER_OPTYPE(CONSTANTOP, "Constant");
GE_REGISTER_OPTYPE(VARIABLE, "Variable");
GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2");

GE_REGISTER_OPTYPE(INPUT_TYPE, "Input");

// Horovod operator
GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce");
GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather");
GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast");
GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait");

GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output");

GE_REGISTER_OPTYPE(RECV, "Recv");
GE_REGISTER_OPTYPE(SEND, "Send");
}; // namespace ge
#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_

+ 0
- 274
metadef/graph/debug/ge_util.h View File

@@ -1,274 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_
#define COMMON_GRAPH_DEBUG_GE_UTIL_H_

#include <limits.h>
#include <math.h>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_log.h"
#include "graph/ge_error_codes.h"

#if !defined(__ANDROID__) && !defined(ANDROID)
#define GE_DYNAMIC_CAST dynamic_cast
#define GE_DYNAMIC_POINTER_CAST std::dynamic_pointer_cast
#else
#define GE_DYNAMIC_CAST static_cast
#define GE_DYNAMIC_POINTER_CAST std::static_pointer_cast
#endif

#define GE_RETURN_IF_ERROR(expr) \
do { \
const ::ge::optStatus _status = (expr); \
if (_status) return _status; \
} while (0)

#define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \
do { \
const ::ge::optStatus _status = (expr); \
if (_status) { \
GELOGI(__VA_ARGS__); \
return _status; \
} \
} while (0)

// Verify whether the parameter is true. If yes, return graph failed and record the error log
#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \
do { \
if (condition) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
return ge::GRAPH_FAILED; \
} \
} while (0)

// Verify whether the parameter is false. If yes, return graph failed and record the error log
#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \
do { \
bool _condition = (condition); \
if (!_condition) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
return ge::GRAPH_FAILED; \
} \
} while (0)

// Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \
do { \
if (condition) { \
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \
do { \
bool _condition = (condition); \
if (!_condition) { \
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log
#define GE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log
#define GE_CHECK_NOTNULL_EXEC(val, expr) \
do { \
if (val == nullptr) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \
expr; \
} \
} while (0)

// Verify whether the parameter is null. If yes, return false and record the error log
#define GE_RT_FALSE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \
return false; \
} \
} while (0)

// Check whether the parameter is out of range
#define GE_CHECK_SIZE(size) \
do { \
if (size == 0) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

///
/// @ingroup GE_common
/// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float));
///
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \
uint32_t _var_name; \
do { \
uint32_t _expr_size = (_expr); \
uint32_t _sizeof_size = (_sizeof); \
if (_expr_size > (0xffffffff) / _sizeof_size) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \
return ge::GRAPH_PARAM_INVALID; \
} \
_var_name = _sizeof_size * _expr_size; \
} while (0);

// Check whether the container is empty
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \
do { \
if (vector.empty()) { \
GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \
return ge::GRAPH_FAILED; \
} \
} while (0)

// Check whether the container is empty and return the specified status code
#define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \
do { \
if (vector.empty()) { \
GELOGE(_status, "param[%s] is empty", #vector); \
return _status; \
} \
} while (0)

///
/// @ingroup GE_common
/// @brief This macro provides the ability to disable copying constructors and assignment operators.
/// It is usually placed under private
///
#define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName &) = delete; \
void operator=(const TypeName &) = delete

/// Check whether the size is 0 or out of range
/// @param:size:Size to be verified
#define GE_CHECK_SIZE_RANGE(size) \
do { \
if (size == 0 || size >= UINT_MAX / 4) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

#define GE_CHECK_SHORT_SIZE_RANGE(size) \
do { \
if (size == 0 || size >= UINT_MAX / 2) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \
do { \
if (size <= 0) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \
do { \
if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Verify that the value on the left is greater than or equal to the value on the right
#define GE_CHECK_GE(lhs, rhs) \
do { \
if (lhs < rhs) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Check whether the parameters are equal
#define GE_CHECK_EQ(val1, val2) \
do { \
if (val1 != val2) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Verify that the value on the left is less than or equal to the value on the right
#define GE_CHECK_LE(lhs, rhs) \
do { \
if (lhs > rhs) { \
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// Check whether the parameters are equal
#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \
do { \
if (val1 != val2) { \
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \
return ge::GRAPH_PARAM_INVALID; \
} \
} while (0)

// If expr is false, the custom statement is executed
#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \
do { \
bool b = (expr); \
if (!b) { \
exec_expr; \
} \
} while (0)

#define GE_DELETE_NEW_SINGLE(var) \
do { \
if (var != nullptr) { \
delete var; \
var = nullptr; \
} \
} while (0)

#define GE_DELETE_NEW_ARRAY(var) \
do { \
if (var != nullptr) { \
delete[] var; \
var = nullptr; \
} \
} while (0)

template <typename T, typename... Args>
static inline std::shared_ptr<T> ComGraphMakeShared(Args &&... args) {
using T_nc = typename std::remove_const<T>::type;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
}

#endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_

+ 0
- 246
metadef/graph/debug/graph_debug.cc View File

@@ -1,246 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/debug/graph_debug.h"
#include <algorithm>
#include <unordered_set>
#include <vector>
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"

#define TAB " "
#define STR_FMT(str) (" \"" + std::string(str) + "\" ")
#define INPUT_ANCHOR_PORT(name) ("__input__" + (name))
#define OUTPUT_ANCHOR_PORT(name) ("__output__" + (name))

namespace ge {
std::unordered_set<std::string> control_anchor;
std::vector<string> types = {
"DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "",
"DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE",
"DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"};

std::vector<string> formats = {"FORMAT_NCHW",
"FORMAT_NHWC",
"FORMAT_ND",
"FORMAT_NC1HWC0",
"FORMAT_FRACTAL_Z",
"FORMAT_NC1C0HWPAD",
"FORMAT_NHWC1C0",
"FORMAT_FSR_NCHW",
"FORMAT_FRACTAL_DECONV",
"FORMAT_C1HWNC0",
"FORMAT_FRACTAL_DECONV_TRANSPOSE",
"FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS",
"FORMAT_NC1HWC0_C04",
"FORMAT_FRACTAL_Z_C04",
"FORMAT_CHWN",
"FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS",
"FORMAT_HWCN",
"FORMAT_NC1KHKWHWC0",
"FORMAT_BN_WEIGHT",
"FORMAT_FILTER_HWCK",
"FORMAT_HASHTABLE_LOOKUP_LOOKUPS",
"FORMAT_HASHTABLE_LOOKUP_KEYS",
"FORMAT_HASHTABLE_LOOKUP_VALUE",
"FORMAT_HASHTABLE_LOOKUP_OUTPUT",
"FORMAT_HASHTABLE_LOOKUP_HITS",
"FORMAT_RESERVED"};

std::vector<string> data_nodes = {"Const", "Data"};

void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &out_) {
if (node == nullptr) {
GELOGI("Some nodes are null.");
return;
}

bool in_control = false;
auto name = node->GetName();
out_ << TAB << STR_FMT(name);
auto input_cnt = std::max(static_cast<size_t>(1), node->GetAllInDataAnchors().size());
auto output_cnt = std::max(static_cast<size_t>(1), node->GetAllOutDataAnchors().size());
if (control_anchor.find(node->GetName()) != control_anchor.end()) {
input_cnt++;
in_control = true;
}
auto max_col = input_cnt * output_cnt;
out_ << "[\n";
if (find(data_nodes.begin(), data_nodes.end(), node->GetType()) != data_nodes.end()) {
out_ << TAB << TAB << "shape=plaintext, color=goldenrod\n";
} else {
out_ << TAB << TAB << "shape=plaintext, color=deepskyblue\n";
}
out_ << TAB << TAB << "label=<\n";
out_ << TAB << TAB << R"(<table border="0" cellborder="1" align="center")"
<< ">" << std::endl;

auto input_anchors = node->GetAllInDataAnchors();
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return );
if (!input_anchors.empty()) {
out_ << TAB << TAB << "<tr>";
}
for (const auto &anchor : input_anchors) {
string anchor_text = op_desc->GetInputNameByIndex(anchor->GetIdx());

out_ << "<td port = " << STR_FMT(INPUT_ANCHOR_PORT(anchor_text)) << " colspan='" << output_cnt << "'>"
<< anchor_text << "</td>";
}
if (in_control) {
string anchor_text = "ctrl";
out_ << "<td port = " << STR_FMT(INPUT_ANCHOR_PORT(anchor_text)) << " colspan='" << output_cnt << "'>"
<< anchor_text << "</td>";
}
if (!input_anchors.empty()) {
out_ << "</tr>\n";
}
// Node type
out_ << TAB << TAB << "<tr><td colspan='" << max_col << "'>"
<< "<b>" << node->GetType() << "</b></td></tr>\n";
// Output
auto output_anchors = node->GetAllOutDataAnchors();
if (!output_anchors.empty()) {
out_ << TAB << TAB << "<tr>";
}
for (const auto &anchor : output_anchors) {
string anchor_text = op_desc->GetOutputNameByIndex(anchor->GetIdx());

out_ << "<td port = " << STR_FMT(OUTPUT_ANCHOR_PORT(anchor_text)) << " colspan='" << input_cnt << "'>"
<< anchor_text << "</td>";
}

if (!output_anchors.empty()) {
out_ << "</tr>\n";
}
out_ << TAB << TAB << "</table>\n" << TAB << ">];\n";
}

void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag) {
if (node == nullptr) {
GELOGI("Some nodes are null.");
return;
}
auto all_out_anchor = node->GetAllOutDataAnchors();
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return );
for (const auto &anchor : all_out_anchor) {
auto src_anchor = anchor;
auto src_node_name = node->GetName();
auto src_anchor_index = op_desc->GetOutputNameByIndex(static_cast<uint32_t>(src_anchor->GetIdx()));
auto des_anchors = anchor->GetPeerAnchors();
for (const auto &peer_in_anchor : des_anchors) {
auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(peer_in_anchor);
std::string dst_node_name;
out_ << TAB << STR_FMT(src_node_name);
out_ << ":" << OUTPUT_ANCHOR_PORT(src_anchor_index);
auto op = peer_in_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op, continue);
if (in_data_anchor != nullptr) {
dst_node_name = in_data_anchor->GetOwnerNode()->GetName();
string des_anchor_index = op->GetInputNameByIndex(static_cast<uint32_t>(in_data_anchor->GetIdx()));
out_ << " -> " << STR_FMT(dst_node_name);
out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index);
out_ << "[";
}
auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(peer_in_anchor);
if (in_control_anchor != nullptr) {
dst_node_name = in_control_anchor->GetOwnerNode()->GetName();
string des_anchor_index = "ctrl";
out_ << " -> " << STR_FMT(dst_node_name);
out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index);
out_ << "[";
out_ << " style=dashed ";
}
if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) {
string label;
auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(src_ops, return );
auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape();
auto dim = src_shape.GetDims();
std::ostringstream tensor_info;
if (dim.size() > 0) {
for (size_t i = 0; i < dim.size(); i++) {
if (i != dim.size() - 1) {
tensor_info << dim[i] << "x";
} else {
tensor_info << dim[i];
}
}
} else {
tensor_info << "?";
}
auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx());
GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return );
auto format = src_tensor_desc->GetFormat();
auto datatype = src_tensor_desc->GetDataType();
tensor_info << " : " << formats[format] << " : " << types[datatype];
label = tensor_info.str();
out_ << "label=" << STR_FMT(label);
}
out_ << "]" << std::endl;
}
}
}

graphStatus GraphDebugPrinter::DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name,
uint32_t flag) {
auto compute_graph = GraphUtils::GetComputeGraph(graph);
if (compute_graph == nullptr) {
GELOGI("Compute graph is NULL .");
return GRAPH_SUCCESS;
}
return DumpGraphDotFile(compute_graph, output_dot_file_name, flag);
}

graphStatus GraphDebugPrinter::DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name,
uint32_t flag) {
if (graph == nullptr) {
GELOGI("graph is null.");
return GRAPH_SUCCESS;
}
std::ostringstream out_;
out_ << "digraph G{\n";
out_ << TAB << R"(ratio=compress;size="8, 100")" << std::endl;
out_ << TAB << R"(node[fontname="Consolas"])" << std::endl;
out_ << TAB << R"(edge[fontsize = "8" fontname = "Consolas" color="dimgray" ])" << std::endl;
auto all_nodes = graph->GetAllNodes();
for (const auto &node : all_nodes) {
for (const auto &temp : node->GetAllOutDataAnchors()) {
for (const auto &peer : temp->GetPeerAnchors()) {
auto temp_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(peer);
if (temp_control_anchor) {
(void)control_anchor.insert(peer->GetOwnerNode()->GetName());
}
}
}
}
for (const auto &node : all_nodes) {
DumpNodeToDot(node, out_);
}
for (const auto &node : all_nodes) {
DumpEdgeToDot(node, out_, flag);
}
out_ << "}";
std::ofstream output_file(output_dot_file_name);
if (output_file.is_open()) {
output_file << out_.str();
} else {
GELOGW("%s open error.", output_dot_file_name.c_str());
}
return GRAPH_SUCCESS;
}
} // namespace ge

+ 0
- 48
metadef/graph/debug/graph_debug.h View File

@@ -1,48 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_
#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_
#include <cstdint>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include "external/graph/graph.h"
#include "./ge_error_codes.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_log.h"
#include "graph/node.h"
#include "utils/graph_utils.h"

namespace ge {
enum DotFileFlag {
// Show nodes, edges, size, type and format
DOT_FLAG_DEFAULT = 0,
DOT_NOT_SHOW_EDGE_LABEL = 1,
};
class GraphDebugPrinter {
public:
static graphStatus DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name,
uint32_t flag = DOT_FLAG_DEFAULT);
static graphStatus DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name,
uint32_t flag = DOT_FLAG_DEFAULT);
static void DumpNodeToDot(const NodePtr node, std::ostringstream &out_);
static void DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag = DOT_FLAG_DEFAULT);
};
} // namespace ge

#endif // COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_

+ 0
- 241
metadef/graph/detail/attributes_holder.cc View File

@@ -1,241 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "detail/attributes_holder.h"
#include <map>
#include "debug/ge_log.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_attr_value.h"
#include "proto/ge_ir.pb.h"

namespace ge {
using std::map;
using std::unordered_set;
void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { MutableAttrMap().CopyValueFrom(holder.GetAttrMap()); }
graphStatus AttrHolder::SetAttr(const std::string &name, const GeAttrValue &value) {
if (value.IsEmpty()) {
GELOGE(GRAPH_FAILED, "value is empty, key %s", name.c_str());
return GRAPH_FAILED;
}
auto proto_map = MutableAttrMap().GetProtoMsg();
auto proto_val = value.value_.GetProtoMsg();
if (proto_map == nullptr || proto_val == nullptr) {
return GRAPH_FAILED;
}
auto it = proto_map->find(name);
if (it != proto_map->end()) {
if (it->second.value_case() != proto::AttrDef::VALUE_NOT_SET &&
it->second.value_case() != proto_val->value_case()) {
return GRAPH_FAILED;
}
}
(*proto_map)[name] = *proto_val;
return GRAPH_SUCCESS;
}

graphStatus AttrHolder::AddRequiredAttr(const std::string &name) {
if (HasAttr(name)) {
return GRAPH_FAILED;
}
requiredAttrs_.push_back(name);
return GRAPH_SUCCESS;
}

graphStatus AttrHolder::GetAttr(const std::string &name, GeAttrValue &value) const {
auto proto_map = GetAttrMap().GetProtoMsg();
auto proto_val = value.value_.GetProtoMsg();
if (proto_map == nullptr || proto_val == nullptr) {
return GRAPH_FAILED;
}
auto it = proto_map->find(name);
if (it != proto_map->end()) {
*proto_val = it->second;
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

bool AttrHolder::HasAttr(const std::string &name) const {
auto proto_map = GetAttrMap().GetProtoMsg();
if (proto_map != nullptr) {
if (proto_map->find(name) != proto_map->end()) {
return true;
}
}
return std::find(requiredAttrs_.begin(), requiredAttrs_.end(), name) != requiredAttrs_.end();
}

graphStatus AttrHolder::DelAttr(const std::string &name) {
auto proto_map = MutableAttrMap().GetProtoMsg();
if (proto_map == nullptr) {
return GRAPH_FAILED;
}
auto it = proto_map->find(name);
if (it != proto_map->end()) {
(void)proto_map->erase(it);
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

const std::map<string, GeAttrValue> AttrHolder::GetAllAttrs() const {
std::map<string, GeAttrValue> attr_value_map;
auto proto_map = GetAttrMap().GetProtoMsg();
if (proto_map != nullptr) {
auto proto_owner = GetAttrMap().GetProtoOwner();
GE_CHK_BOOL_EXEC(proto_owner != nullptr, return attr_value_map, "proto_owner is nullptr");
for (const auto &it : *proto_map) {
attr_value_map[it.first] = GeAttrValue(proto_owner, const_cast<proto::AttrDef *>(&it.second));
}
}
return attr_value_map;
}

const std::unordered_set<string> AttrHolder::GetAllAttrNames() const {
std::unordered_set<string> names;
auto proto_map = GetAttrMap().GetProtoMsg();
if (proto_map != nullptr) {
for (const auto &it : *proto_map) {
(void)names.insert(it.first);
}
}
for (const string &it : requiredAttrs_) {
(void)names.insert(it);
}
return names;
}

template <>
void GeIrProtoHelper<proto::AttrDef>::InitDefault() {
std::shared_ptr<proto::AttrDef> proto_owner;
proto_owner = ComGraphMakeShared<proto::AttrDef>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::AttrDef make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::TensorDef>::InitDefault() {
std::shared_ptr<proto::TensorDef> proto_owner;
proto_owner = ComGraphMakeShared<proto::TensorDef>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::TensorDef make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::TensorDescriptor>::InitDefault() {
std::shared_ptr<proto::TensorDescriptor> proto_owner;
proto_owner = ComGraphMakeShared<proto::TensorDescriptor>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::ShapeDef>::InitDefault() {
std::shared_ptr<proto::ShapeDef> proto_owner;
proto_owner = ComGraphMakeShared<proto::ShapeDef>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::ShapeDef make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::NamedAttrs>::InitDefault() {
std::shared_ptr<proto::NamedAttrs> proto_owner;
proto_owner = ComGraphMakeShared<proto::NamedAttrs>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::NamedAttrs make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::ModelDef>::InitDefault() {
std::shared_ptr<proto::ModelDef> proto_owner;
proto_owner = ComGraphMakeShared<proto::ModelDef>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::OpDef>::InitDefault() {
std::shared_ptr<proto::OpDef> proto_owner;
proto_owner = ComGraphMakeShared<proto::OpDef>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<proto::GraphDef>::InitDefault() {
std::shared_ptr<proto::GraphDef> proto_owner;
proto_owner = ComGraphMakeShared<proto::GraphDef>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
return;
}
protoMsg_ = proto_owner.get();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<ProtoAttrMap>::InitDefault() {
std::shared_ptr<proto::TensorDescriptor> proto_owner;
proto_owner = ComGraphMakeShared<proto::TensorDescriptor>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed");
return;
}
protoMsg_ = proto_owner->mutable_attr();
protoOwner_ = proto_owner;
}

template <>
void GeIrProtoHelper<const ProtoAttrMap>::InitDefault() {
std::shared_ptr<proto::TensorDescriptor> proto_owner;
proto_owner = ComGraphMakeShared<proto::TensorDescriptor>();
if (proto_owner == nullptr) {
GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed");
return;
}
protoMsg_ = &proto_owner->attr();
protoOwner_ = proto_owner;
}
} // namespace ge

+ 0
- 508
metadef/graph/format_refiner.cc View File

@@ -1,508 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "format_refiner.h"

#include <deque>
#include <iostream>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "graph/ref_relation.h"
#include "./compute_graph.h"
#include "./ge_error_codes.h"
#include "./graph/ge_tensor.h"
#include "./operator.h"
#include "./operator_factory.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

using namespace ge;
using namespace std;
namespace ge {
namespace {
const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
const string kIsGraphInferred = "_is_graph_inferred";
thread_local RefRelations reflection_builder;
} // namespace

graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection,
std::deque<ge::NodePtr> &nodes, ge::Format to_be_set_format) {
for (const auto &cell : reflection) {
auto node = cell.node;
auto in_out_idx = cell.in_out_idx;
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
if (cell.in_out == ge::NODE_IN) {
auto desc = node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(in_out_idx));
desc.SetOriginFormat(to_be_set_format);
desc.SetFormat(to_be_set_format);
(void)node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(in_out_idx), desc);
} else {
auto desc = node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(in_out_idx));
desc.SetOriginFormat(to_be_set_format);
desc.SetFormat(to_be_set_format);
(void)node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(in_out_idx), desc);
}
nodes.push_back(cell.node);
}

return GRAPH_SUCCESS;
}

graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) {
// 5 meas dim num
if (node_ptr->GetType() != "BiasAdd") {
return GRAPH_SUCCESS;
}
std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}};
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) {
auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i);
GE_CHECK_NOTNULL(in_desc);
if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num
continue;
}
auto format = in_desc->GetOriginFormat();
auto key = TypeUtils::FormatToSerialString(format);
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
in_desc->SetOriginFormat(fixed_format);
in_desc->SetFormat(fixed_format);
GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i,
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::FormatToSerialString(fixed_format).c_str());
}
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) {
auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i);
GE_CHECK_NOTNULL(out_desc);
if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num
continue;
}
auto format = out_desc->GetOriginFormat();
auto key = TypeUtils::FormatToSerialString(format);
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
out_desc->SetOriginFormat(fixed_format);
out_desc->SetFormat(fixed_format);
GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i,
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::FormatToSerialString(fixed_format).c_str());
}
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(graph);
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) {
ConstGeTensorPtr tensor_value;
if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str());
return GRAPH_FAILED;
}
GE_CHECK_NOTNULL(tensor_value);
(void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc());
}
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
std::vector<ge::NodePtr> &data_nodes,
std::unordered_map<ge::NodePtr, bool> &node_status) {
if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "input graph is null");
return GRAPH_FAILED;
}
anchor_points.clear();
// Get all anchor point nodes and switch nodes
for (auto &node_ptr : graph->GetAllNodes()) {
if (node_ptr == nullptr) {
return GRAPH_FAILED;
}
auto op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
return GRAPH_FAILED;
}
graphStatus status = RefreshConstantOutProcess(graph, op_desc);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "refresh constant out process failed!");
return GRAPH_FAILED;
}
// consider special node save process
// get all input desc format
bool node_is_all_nd = false;
auto input_size = static_cast<uint32_t>(op_desc->GetAllInputsSize());
for (uint32_t i = 0; i < input_size; i++) {
// Operator pre-set format but not origin format
GE_IF_BOOL_EXEC(op_desc->MutableInputDesc(i) == nullptr, continue);
auto input_format = op_desc->MutableInputDesc(i)->GetFormat();
// Pre-save data node (only main graph data) and default infer fail
if (node_ptr->GetType() == DATA) {
data_nodes.push_back(node_ptr);
}
if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) {
node_is_all_nd = true;
}
}
// Get all output desc format
auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize());
for (uint32_t i = 0; i < output_size; i++) {
GE_IF_BOOL_EXEC(op_desc->MutableOutputDesc(i) == nullptr, continue);
auto output_format = op_desc->MutableOutputDesc(i)->GetFormat();
if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) {
node_is_all_nd = true;
}
}
// check anchor point valid
if (!node_is_all_nd) {
continue;
}
// special process for biasAdd op
// In tensorflow, biasAdd's format is alwayse NHWC even though set the arg
// "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism
// so here do special process
status = BiasAddFormatFixProcess(node_ptr);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "fix biasAdd process failed!");
return GRAPH_FAILED;
}

GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str());
anchor_points.push_back(node_ptr);
}
GELOGI("anchor_points number is %zu", anchor_points.size());
return GRAPH_SUCCESS;
}
graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node,
std::unordered_map<ge::NodePtr, bool> &node_status) {
if (anchor_node == nullptr) {
GELOGE(GRAPH_FAILED, "anchor node is null!");
return GRAPH_FAILED;
}
std::deque<ge::NodePtr> nodes;
nodes.push_back(anchor_node);
while (!nodes.empty()) {
ge::NodePtr node = nodes.front();
nodes.pop_front();
graphStatus status = BackInferProcess(nodes, node, node_status);
if (status != GRAPH_SUCCESS && node != nullptr) {
GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str());
return status;
}
status = ForwardInferProcess(nodes, node, node_status);
if (status != GRAPH_SUCCESS && node != nullptr) {
GELOGE(status, "ForwardInferProcess failed!node name [%s]", node->GetName().c_str());
return status;
}
}
return GRAPH_SUCCESS;
}
graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());

GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str());
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
GELOGD("Node is [%s] [B]", (node->GetName()).c_str());
auto in_data_anchor_idx = in_anchor->GetIdx();
auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx));
GE_IF_BOOL_EXEC(input_desc == nullptr, continue);
auto to_be_set_format = input_desc->GetOriginFormat();
if (to_be_set_format == FORMAT_ND) {
GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str());
continue;
}
auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor == nullptr) {
GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx);
continue;
}
auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str());
continue;
}
// Check format whether have been set
int idx = peer_out_data_anchor->GetIdx();
// do peer_out_node name and index as key to lookup reflections
ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx);
std::unordered_set<RefCell, RefCellHash> reflection;
auto status = reflection_builder.LookUpRefRelations(key, reflection);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge",
(peer_out_data_node->GetName()).c_str(), idx);
return GRAPH_FAILED;
}

auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx));
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) {
auto dim_num = ge_tensor_desc.GetShape().GetDimNum();
if (dim_num == 0) {
GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx);
continue;
}
/// Check whether node to change dims ()
/// Because some node will calculate with 5D, C dim maybe multi meaning
auto peer_out_data_node_type = peer_out_data_node->GetType();
auto iter1 = kChangeDimNodes.find(peer_out_data_node_type);
// 4 means dims num
if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) {
GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format",
(peer_out_data_node->GetName()).c_str());
continue;
}

if (reflection.empty()) {
ge_tensor_desc.SetOriginFormat(to_be_set_format);
ge_tensor_desc.SetFormat(to_be_set_format);
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc);

// Call operator infer format api (forward) to get out format
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str());
status = peer_out_data_node->InferOriginFormat();
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str());
return GRAPH_FAILED;
}
nodes.push_back(peer_out_data_node);
} else {
auto status = ReflectionProcess(reflection, nodes, to_be_set_format);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "reflection process failed!");
return GRAPH_FAILED;
}
}
}
}
return GRAPH_SUCCESS;
}
graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str());
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
GELOGD("Node is [%s] [F]", (node->GetName()).c_str());
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue);
auto out_data_anchor_idx = out_data_anchor->GetIdx();
auto to_be_set_format =
node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat();
if (to_be_set_format == FORMAT_ND) {
GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str());
continue;
}
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue);

auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue);
GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue);

// Check format whether have been set
int idx = peer_in_data_anchor->GetIdx();
// do peer_out_node name and index as key to lookup reflections
ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx);
std::unordered_set<RefCell, RefCellHash> reflection;
auto status = reflection_builder.LookUpRefRelations(key, reflection);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge",
(peer_in_data_node->GetName()).c_str(), idx);
return GRAPH_FAILED;
}
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx));
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) {
auto dim_num = ge_tensor_desc.GetShape().GetDimNum();
if (dim_num == 0) {
GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx);
continue;
}
/// Check whether node to change dims ()
/// Because some node will calculate with 5D, C dim maybe multi meaning
auto peer_in_data_node_type = peer_in_data_node->GetType();
auto iter1 = kChangeDimNodes.find(peer_in_data_node_type);
// 4 means dims num
if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) {
GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str());
continue;
}

if (reflection.empty()) {
ge_tensor_desc.SetOriginFormat(to_be_set_format);
ge_tensor_desc.SetFormat(to_be_set_format);
(void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(idx), ge_tensor_desc);

/// Because netoutput node added before infer format ,so netoutput is end condition
/// must set netoutput format , because saved result depend on format
if (peer_in_data_node_type == NETOUTPUT) {
continue;
}

// Call operator infer format api (forward) to get out format
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str());
status = peer_in_data_node->InferOriginFormat();
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str());
return GRAPH_FAILED;
}
nodes.push_back(peer_in_data_node);
} else {
auto status = ReflectionProcess(reflection, nodes, to_be_set_format);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "reflection process failed!");
return GRAPH_FAILED;
}
}
}
}
}
return GRAPH_SUCCESS;
}

void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor_points) {
for (const auto &node : anchor_points) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
continue;
}
for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) {
if (input_desc != nullptr) {
input_desc->SetOriginFormat(input_desc->GetFormat());
}
}
for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) {
if (output_desc != nullptr) {
output_desc->SetOriginFormat(output_desc->GetFormat());
}
}
}
}

graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status) {
if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) {
GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph),
TypeUtils::FormatToSerialString(data_format).c_str());
return GRAPH_SUCCESS;
}
GELOGD("Enter DataNodeFormatProcess");
std::vector<ge::NodePtr> uninfered_data_nodes;
// Check and renew data nodes format
for (const auto &data_node : data_nodes) {
GE_CHECK_NOTNULL(data_node);
auto op_desc = data_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0));
auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat();
if (curr_format != FORMAT_ND) {
// Data format has been infered , continue
continue;
}
// Set format for un-infered data node
auto input_descs = op_desc->GetAllInputsDescPtr();
auto output_descs = op_desc->GetAllOutputsDescPtr();

for (const auto &input_desc : input_descs) {
if (input_desc != nullptr) {
input_desc->SetOriginFormat(data_format);
input_desc->SetFormat(data_format);
}
}
for (const auto &output_desc : output_descs) {
if (output_desc != nullptr) {
output_desc->SetOriginFormat(data_format);
output_desc->SetFormat(data_format);
}
}
uninfered_data_nodes.push_back(data_node);
}
// Reinfer format from uninfered data nodes
for (const auto &node : uninfered_data_nodes) {
if (node == nullptr) {
continue;
}
GELOGD("data node [%s] start infer format process", node->GetName().c_str());
auto status = AnchorProcess(node, node_status);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str());
return GRAPH_FAILED;
}
}
GELOGD("DataNodeFormatProcess success");
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) {
GELOGI("Enter InferOrigineFormat process!");

// True: infered false:no-infered
std::unordered_map<ge::NodePtr, bool> node_status;
std::vector<ge::NodePtr> anchor_points;
std::vector<ge::NodePtr> data_nodes;
// global net format

if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "input graph is null");
return GRAPH_FAILED;
}
// build reflection relations of boundary
(void)reflection_builder.Clear();
auto status = reflection_builder.BuildRefRelations(*graph);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!");
return GRAPH_FAILED;
}
// User set global net format
status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!");
return GRAPH_FAILED;
}
// Refresh origin format of anchor point
RefreshOriginFormatOfAnchor(anchor_points);
// Infer format process
for (const auto &anchor_node : anchor_points) {
if (anchor_node == nullptr) {
continue;
}
status = AnchorProcess(anchor_node, node_status);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str());
return GRAPH_FAILED;
}
}
/// According to discuss with sys-enginer, data node default format is ND.Its format
/// should be set by infered.But if some data-node can not be got by infer, set context's
/// format for these data nodes.
/// Notice: ignore 5D formats
auto data_format = graph->GetDataFormat();
status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status);

(void)AttrUtils::SetBool(graph, kIsGraphInferred, true);

return status;
}

bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) {
bool is_graph_inferred = false;
return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred);
}
} // namespace ge

+ 0
- 50
metadef/graph/format_refiner.h View File

@@ -1,50 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_FORMAT_REFINER_H_
#define COMMON_GRAPH_FORMAT_REFINER_H_

#include <deque>
#include <string>
#include <unordered_map>
#include <vector>
#include "./compute_graph.h"
#include "./external/graph/types.h"
#include "./ge_error_codes.h"

namespace ge {
// ShapeRefiner performs shape inference for compute graphs
class FormatRefiner {
public:
static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph);

private:
static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
std::vector<ge::NodePtr> &data_nodes,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus AnchorProcess(const ge::NodePtr &anchor_node, std::unordered_map<ge::NodePtr, bool> &node_status);
static void RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor_points);
static graphStatus BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status);
static bool IsGraphInferred(const ComputeGraphPtr &graph);
};
} // namespace ge
#endif // COMMON_GRAPH_FORMAT_REFINER_H_

+ 0
- 1078
metadef/graph/ge_attr_define.cc
File diff suppressed because it is too large
View File


+ 0
- 1289
metadef/graph/ge_attr_value.cc
File diff suppressed because it is too large
View File


+ 0
- 1021
metadef/graph/ge_tensor.cc
File diff suppressed because it is too large
View File


+ 0
- 384
metadef/graph/graph.cc View File

@@ -1,384 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "external/graph/graph.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
#include "graph/model.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"

using std::map;
using std::pair;
using std::string;
using std::vector;

namespace ge {
class GraphImpl {
public:
friend class GraphUtils;
GraphImpl(const GraphImpl &) = delete;
GraphImpl &operator=(const GraphImpl &) = delete;

explicit GraphImpl(const std::string &name) : name_(name) {}

~GraphImpl() {
if (IsValid()) {
if (compute_graph_ != nullptr) {
GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo());
}
}
for (const auto &it : op_list_) {
Operator op = it.second;
op.BreakConnect();
}
}

graphStatus SetInputs(const std::vector<Operator> &inputs) {
compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs);
GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed.");
GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL.");
compute_graph_->SetInputSize(static_cast<uint32_t>(inputs.size()));
return GRAPH_SUCCESS;
}

graphStatus SetOutputs(const std::vector<Operator> &outputs) {
if (compute_graph_ == nullptr) {
GELOGE(GRAPH_FAILED, "set ComputeGraph failed.");
return GRAPH_FAILED;
}
if (outputs.empty()) {
GELOGW("set outputs size is 0.");
return GRAPH_SUCCESS;
}

// Construct special output node
std::vector<std::pair<Operator, std::vector<size_t>>> output_indexs;
for (size_t i = 0; i < outputs.size(); ++i) {
output_indexs.emplace_back(outputs[i], std::vector<size_t>{});
}

graphStatus ret = SetOutputs(output_indexs);
return ret;
}

graphStatus SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) {
if (compute_graph_ == nullptr) {
GELOGE(GRAPH_FAILED, "set ComputeGraph failed.");
return GRAPH_FAILED;
}
if (output_indexs.empty()) {
GELOGW("set outputs size is 0.");
return GRAPH_SUCCESS;
}

// Construct special output node
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
for (const auto &item : output_indexs) {
const Operator &output = item.first;
const vector<size_t> &indexs = item.second;
ge::NodePtr node = compute_graph_->FindNode(output.GetName());
if (node == nullptr) {
GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str());
continue;
}

ge::OpDescPtr tmp_op_ptr = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue);
size_t out_size = tmp_op_ptr->GetOutputsSize();
if (indexs.empty()) {
for (size_t i = 0; i < out_size; ++i) {
output_name_ += output.GetName() + ":" + std::to_string(i) + ";";
output_nodes.emplace_back(node, i);
}
} else {
for (size_t i = 0; i < indexs.size(); ++i) {
if (indexs[i] >= out_size) {
GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str());
} else {
output_name_ += output.GetName() + ":" + std::to_string(i) + ";";
output_nodes.emplace_back(node, indexs[i]);
}
}
}
}

// Del last ";"
if (!output_name_.empty()) {
output_name_ = output_name_.substr(0, output_name_.length() - 1);
}
compute_graph_->SetUserDefOutput(output_name_);
compute_graph_->SetOutputSize(static_cast<uint32_t>(output_indexs.size()));
compute_graph_->SetGraphOutNodesInfo(output_nodes);
return GRAPH_SUCCESS;
}

graphStatus SetOutputs(const std::vector<pair<Operator, string>> &outputs) {
GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild.");
GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0.");

// Construct specified output
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
for (auto item : outputs) {
ge::NodePtr node = compute_graph_->FindNode(item.first.GetName());
if (node == nullptr) {
GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!",
item.first.GetName().c_str());
return GRAPH_FAILED;
}
ge::OpDescPtr tmp_op_ptr = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue);
size_t out_size = tmp_op_ptr->GetOutputsSize();

if (item.second.empty()) {
for (size_t i = 0; i < out_size; ++i) {
output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";";
output_nodes.push_back(std::make_pair(node, i));
}
} else {
int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second);
if (index < 0) {
GELOGE(GRAPH_FAILED,
" Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!",
item.first.GetName().c_str(), item.second.c_str());
return GRAPH_FAILED;
}
output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";";
output_nodes.push_back(std::make_pair(node, index));
}
}
// Del last ";"
if (!output_name_.empty()) {
output_name_ = output_name_.substr(0, output_name_.length() - 1);
}
compute_graph_->SetOutputSize(static_cast<uint32_t>(outputs.size()));
compute_graph_->SetGraphOutNodesInfo(output_nodes);
GELOGI("********************SetOutputs Success***********************");
GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str()));

return GRAPH_SUCCESS;
}

graphStatus SetTargets(const std::vector<Operator> &targets) {
GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild.");
GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0.");

std::vector<ge::NodePtr> target_nodes;
for (auto item : targets) {
ge::NodePtr node = compute_graph_->FindNode(item.GetName());
if (node == nullptr) {
GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!",
item.GetName().c_str());
continue;
}
target_nodes.push_back(node);
}
compute_graph_->SetGraphTargetNodesInfo(target_nodes);
return GRAPH_SUCCESS;
}
bool IsValid() const { return (compute_graph_ != nullptr); }

graphStatus AddOp(const ge::Operator &op) {
std::pair<std::map<string, ge::Operator>::iterator, bool> ret;
ret = op_list_.emplace(std::pair<string, ge::Operator>(op.GetName(), op));
GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.",
op.GetName().c_str());
return GRAPH_SUCCESS;
}

graphStatus GetAllOpName(std::vector<string> &op_name) const {
for (const auto &it : op_list_) {
op_name.push_back(it.second.GetName());
}
return GRAPH_SUCCESS;
}

graphStatus FindOpByName(const string &name, ge::Operator &op) const {
auto it = op_list_.find(name);
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
op = it->second;
return GRAPH_SUCCESS;
}

graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const {
for (auto &op : op_list_) {
auto op_type = op.second.GetOpType();
if (op_type == type) {
ops.push_back(op.second);
continue;
}
if (op_type == ge::FRAMEWORKOP) {
op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type);
if (op_type == type) {
ops.push_back(op.second);
}
}
}
return GRAPH_SUCCESS;
}

void SetNeedIteration(bool need_iteration) {
if (compute_graph_ == nullptr) {
GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null.");
return;
}
compute_graph_->SetNeedIteration(need_iteration);
}

const std::string &GetName() const { return name_; }

private:
std::string name_;
std::string output_name_;
std::map<string, ge::Operator> op_list_;
ComputeGraphPtr compute_graph_{nullptr};
};

Graph::Graph(const std::string &name) {
impl_ = ComGraphMakeShared<GraphImpl>(name);
if (impl_ == nullptr) {
GELOGW("GraphImpl make shared failed, impl_ is nullptr");
}
}

graphStatus Graph::AddOp(const ge::Operator &op) {
GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr.");
return impl_->AddOp(op);
}

graphStatus Graph::GetAllOpName(std::vector<string> &op_name) const {
GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
"GetAllOpName failed: graph can not be used, impl is nullptr.");
return impl_->GetAllOpName(op_name);
}

graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const {
Operator op_find_op_def("NULL");
op = op_find_op_def;
GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
"FindOpByName failed: graph can not be used, impl is nullptr.");
return impl_->FindOpByName(name, op);
}

graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const {
GE_CHECK_NOTNULL(impl_);
return impl_->FindOpByType(type, ops);
}

Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) {
GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.")
GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0.");
(void)impl_->SetInputs(inputs);
return *this;
}

Graph &Graph::SetOutputs(const vector<ge::Operator> &outputs) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr.");
return *this;
}
(void)impl_->SetOutputs(outputs);
return *this;
}

Graph &Graph::SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr.");
return *this;
}
(void)impl_->SetOutputs(output_indexs);
return *this;
}

Graph &Graph::SetOutputs(const std::vector<pair<Operator, string>> &outputs) {
GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.")
(void)impl_->SetOutputs(outputs);
return *this;
}

Graph &Graph::SetTargets(const vector<ge::Operator> &targets) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr.");
return *this;
}
(void)impl_->SetTargets(targets);
return *this;
}

bool Graph::IsValid() const {
if (impl_ == nullptr) {
return false;
}
return impl_->IsValid();
}

void Graph::SetNeedIteration(bool need_iteration) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null.");
return;
}
impl_->SetNeedIteration(need_iteration);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) {
GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr);
return graph.impl_->compute_graph_;
}

graphStatus Graph::SaveToFile(const string &file_name) const {
Model model = Model();
model.SetGraph(*this);
return model.SaveToFile(file_name);
}

graphStatus Graph::LoadFromFile(const string &file_name) {
Model model = Model();
graphStatus ret = model.LoadFromFile(file_name);
if (ret != GRAPH_SUCCESS) {
return ret;
}
*this = model.GetGraph();
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph
GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) {
GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph(""));

auto name = compute_graph->GetName();
auto graph = Graph(name);

GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph);
graph.impl_->compute_graph_ = compute_graph;

return graph;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) {
GE_CHECK_NOTNULL(graph.impl_);
GE_CHECK_NOTNULL(graph.impl_->compute_graph_);

graph.impl_->op_list_.clear();
for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) {
graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node);
}
return SUCCESS;
}
} // namespace ge

+ 0
- 294
metadef/graph/graph.mk View File

@@ -1,294 +0,0 @@
LOCAL_PATH := $(call my-dir)
include $(LOCAL_PATH)/stub/Makefile
COMMON_LOCAL_SRC_FILES := \
./proto/om.proto \
./proto/ge_ir.proto \
./proto/ge_onnx.proto \
./proto/insert_op.proto \
./proto/task.proto \
./proto/fwk_adapter.proto \
./proto/op_mapping_info.proto \
./proto/dump_task.proto \
./anchor.cc \
./ge_attr_value.cc \
./attr_value.cc \
./buffer.cc \
./compute_graph.cc \
./graph.cc \
./inference_context.cc \
./shape_refiner.cc \
./format_refiner.cc \
./ref_relation.cc \
./model.cc \
./model_serialize.cc \
./node.cc \
./op_desc.cc \
./operator.cc \
./operator_factory.cc \
./operator_factory_impl.cc \
./ge_attr_define.cc \
./ge_tensor.cc \
./detail/attributes_holder.cc \
./utils/anchor_utils.cc \
./utils/tuning_utils.cc \
./utils/graph_utils.cc \
./utils/ge_ir_utils.cc \
./utils/node_utils.cc \
./utils/op_desc_utils.cc \
./utils/type_utils.cc \
./utils/tensor_utils.cc \
./tensor.cc \
./debug/graph_debug.cc \
./opsproto/opsproto_manager.cc \
../ops/op_imp.cpp \
option/ge_context.cc \
option/ge_local_context.cc \
./runtime_inference_context.cc \

COMMON_LOCAL_C_INCLUDES := \
proto/om.proto \
proto/ge_ir.proto \
proto_inner/ge_onnx.proto \
proto/insert_op.proto \
proto/task.proto \
proto/fwk_adapter.proto \
proto/op_mapping_info.proto \
proto/dump_task.proto \
inc \
inc/external \
inc/external/graph \
inc/graph \
inc/common \
common \
common/graph \
third_party/protobuf/include \
libc_sec/include \
ops/built-in/op_proto/inc \


#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := stub/libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/inference_context.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := stub/libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/inference_context.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

# compile for ut/st
include $(CLEAR_VARS)
LOCAL_MODULE := libgraph

LOCAL_CFLAGS +=

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_LLT_SHARED_LIBRARY)


#compiler for host static lib
include $(CLEAR_VARS)
LOCAL_MODULE := libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)

LOCAL_STATIC_LIBRARIES := \
libprotobuf \

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_STATIC_LIBRARY)

#compiler for device static lib
include $(CLEAR_VARS)
LOCAL_MODULE := libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)

LOCAL_STATIC_LIBRARIES := \
libprotobuf \

LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_STATIC_LIBRARY)

+ 0
- 112
metadef/graph/inference_context.cc View File

@@ -1,112 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "external/graph/inference_context.h"
#include "debug/ge_util.h"

namespace ge {
class ShapeAndTypeImpl {
public:
ShapeAndTypeImpl() = default;
~ShapeAndTypeImpl() = default;

ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {}

Shape shape_;
DataType data_type_ = DT_UNDEFINED;
};

class InferenceContextImpl {
public:
InferenceContextImpl() = default;
~InferenceContextImpl() = default;

// For deliver to op in pair, help to support dynamic shape
std::vector<std::string> marks_;
std::vector<std::vector<ShapeAndType>> input_handle_shapes_and_types_;
std::vector<std::vector<ShapeAndType>> output_handle_shapes_and_types_;
};

ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared<ShapeAndTypeImpl>(); }

ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) {
shape_and_type_impl_ = ComGraphMakeShared<ShapeAndTypeImpl>(shape, data_type);
}

void ShapeAndType::SetShape(const Shape &shape) {
if (shape_and_type_impl_ != nullptr) {
shape_and_type_impl_->shape_ = shape;
}
}

void ShapeAndType::SetType(DataType data_type) {
if (shape_and_type_impl_ != nullptr) {
shape_and_type_impl_->data_type_ = data_type;
}
}

Shape ShapeAndType::GetShape() const {
if (shape_and_type_impl_ != nullptr) {
return shape_and_type_impl_->shape_;
}
return Shape();
}

DataType ShapeAndType::GetDataType() const {
if (shape_and_type_impl_ != nullptr) {
return shape_and_type_impl_->data_type_;
}
return DT_UNDEFINED;
}

InferenceContext::InferenceContext(std::unique_ptr<InferenceContextImpl> &impl) {
inference_context_impl_ = std::move(impl);
}

std::unique_ptr<InferenceContext> InferenceContext::Create() {
std::unique_ptr<InferenceContextImpl> impl =
std::unique_ptr<InferenceContextImpl>(new (std::nothrow) InferenceContextImpl());
if (impl == nullptr) {
return nullptr;
}

return std::unique_ptr<InferenceContext>(new (std::nothrow) InferenceContext(impl));
}

void InferenceContext::SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) {
inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types);
}

const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetInputHandleShapesAndTypes() const {
return inference_context_impl_->input_handle_shapes_and_types_;
}

const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetOutputHandleShapesAndTypes() const {
return inference_context_impl_->output_handle_shapes_and_types_;
}

void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types) {
inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types;
}

void InferenceContext::SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) {
inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types);
}

void InferenceContext::SetMarks(const std::vector<std::string> &marks) { inference_context_impl_->marks_ = marks; }

const std::vector<std::string> &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; }
} // namespace ge

+ 0
- 190
metadef/graph/model.cc View File

@@ -1,190 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/model.h"
#include <fcntl.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iomanip>
#include "debug/ge_attr_define.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/model_serialize.h"
#include "proto/ge_ir.pb.h"
#include "utils/attr_utils.h"
#include "utils/ge_ir_utils.h"

using google::protobuf::io::FileInputStream;
using google::protobuf::io::FileOutputStream;
using google::protobuf::io::ZeroCopyInputStream;

namespace {
const int DEFAULT_VERSION = 1;
const int ACCESS_PERMISSION_BITS = 0400;
} // namespace

namespace ge {
void Model::Init() {
(void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0);
(void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI);
version_ = 0;
}

Model::Model() {
attrs_.InitDefault();
Init();
}

Model::Model(const string &name, const string &custom_version)
: name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) {
attrs_.InitDefault();
Init();
}

string Model::GetName() const { return name_; }

void Model::SetName(const string &name) { name_ = name; }

uint32_t Model::GetVersion() const { return version_; }

string Model::GetPlatformVersion() const { return platform_version_; }

void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; }

Graph Model::GetGraph() const { return graph_; }

graphStatus Model::Save(Buffer &buffer, bool is_dump) const {
ModelSerialize serialize;
buffer = serialize.SerializeModel(*this, is_dump);
return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED;
}

void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; }

graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) {
ModelSerialize serialize;
model = serialize.UnserializeModel(data, len);
return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED;
}

graphStatus Model::SaveToFile(const string &file_name) const {
Buffer buffer;
if ((*this).Save(buffer) != GRAPH_SUCCESS) {
GE_LOGE("save to file fail.");
return GRAPH_FAILED;
}
// Write file
ge::proto::ModelDef ge_proto;
if (buffer.GetData() != nullptr) {
std::string str((const char *)buffer.GetData(), buffer.GetSize());
if (!ge_proto.ParseFromString(str)) {
return GRAPH_FAILED;
}
char real_path[PATH_MAX] = {0x00};
if (strlen(file_name.c_str()) >= PATH_MAX) {
return GRAPH_FAILED;
}
if (realpath(file_name.c_str(), real_path) == nullptr) {
GELOGI("file %s does not exit, it will be created.", file_name.c_str());
}
int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS);
if (fd < 0) {
GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno));
return GRAPH_FAILED;
}
bool ret = ge_proto.SerializeToFileDescriptor(fd);
if (!ret) {
GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed");
if (close(fd) != 0) {
GELOGE(GRAPH_FAILED, "close file descriptor fail.");
return GRAPH_FAILED;
}
return GRAPH_FAILED;
}
if (close(fd) != 0) {
GELOGE(GRAPH_FAILED, "close file descriptor fail.");
return GRAPH_FAILED;
}
if (!ret) {
GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed");
return GRAPH_FAILED;
}
}
return GRAPH_SUCCESS;
}

graphStatus Model::Load(ge::proto::ModelDef &model_def) {
ModelSerialize serialize;
*this = serialize.UnserializeModel(model_def);
return this->IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED;
}

bool Model::IsValid() const { return graph_.IsValid(); }

graphStatus Model::LoadFromFile(const string &file_name) {
char real_path[PATH_MAX] = {0x00};
if (strlen(file_name.c_str()) >= PATH_MAX) {
return GRAPH_FAILED;
}
if (realpath(file_name.c_str(), real_path) == nullptr) {
GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str());
return GRAPH_FAILED;
}
int fd = open(real_path, O_RDONLY);
if (fd < 0) {
GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno));
return GRAPH_FAILED;
}

ge::proto::ModelDef model_def;
bool ret = model_def.ParseFromFileDescriptor(fd);
if (!ret) {
GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed");
if (close(fd) != 0) {
GELOGE(GRAPH_FAILED, "close file descriptor fail.");
return GRAPH_FAILED;
}
return GRAPH_FAILED;
}
if (close(fd) != 0) {
GELOGE(GRAPH_FAILED, "close file descriptor fail.");
return GRAPH_FAILED;
}
if (!ret) {
GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed");
return GRAPH_FAILED;
}
return Load(model_def);
}

ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; }

ConstProtoAttrMapHelper Model::GetAttrMap() const {
return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
}
} // namespace ge

+ 0
- 763
metadef/graph/model_serialize.cc View File

@@ -1,763 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/model_serialize.h"
#include <google/protobuf/text_format.h>

#include <queue>
#include <iostream>

#include "debug/ge_attr_define.h"
#include "debug/ge_log.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/detail/model_serialize_imp.h"
#include "proto/ge_ir.pb.h"
#include "utils/graph_utils.h"
#include "debug/ge_op_types.h"

using std::map;
using std::string;

namespace ge {
bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) {
auto sep = node_index.rfind(":");
if (sep == string::npos) {
GELOGW("separator is not found in node_index.");
return false;
}
node_name = node_index.substr(0, sep);
auto index_str = node_index.substr(sep + 1);
index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor,
proto::TensorDef *tensor_proto) {
GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null.");
GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null.");

if (tensor->tensor_def_.GetProtoMsg() != nullptr) {
*tensor_proto = *tensor->tensor_def_.GetProtoMsg();
return true;
}
return false;
}

bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) {
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null.");
GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");

op_def_proto->clear_input();
// Inputs
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
if (in_data_anchor != nullptr) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
std::to_string(peer_out_anchor->GetIdx()));
} else {
op_def_proto->add_input("");
}
}
}
// Control edge
auto control_anchor = node->GetInControlAnchor();
if (control_anchor != nullptr) {
auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors();
for (const auto &peer_out_anchor : peer_out_anchors) {
if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1");
}
}
}
return true;
}

bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
if (op_desc->op_def_.GetProtoMsg() != nullptr) {
*op_def_proto = *op_desc->op_def_.GetProtoMsg();
// Delete unnecessary attr
if (is_dump) {
auto attr = op_def_proto->mutable_attr();
attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF);
attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF);
attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF);
GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP),
attr->erase(ATTR_NAME_WEIGHTS));
}
op_def_proto->clear_input_desc();
op_def_proto->clear_output_desc();
// Input descs
if (op_desc->GetAllInputsSize() > 0) {
auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize());
for (uint32_t i = 0; i < size; i++) {
auto tensor_desc = op_desc->GetInputDescPtrDfault(i);
if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
*op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
}
}
}
// Output descs
if (op_desc->GetOutputsSize() > 0) {
auto size = static_cast<uint32_t>(op_desc->GetOutputsSize());
for (uint32_t i = 0; i < size; i++) {
auto tensor_desc = op_desc->GetOutputDescPtr(i);
if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
*op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
}
}
}

op_def_proto->set_id(op_desc->GetId());
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
op_def_proto->add_subgraph_name(name);
}
OpDescToAttrDef(op_desc, op_def_proto);
}
return true;
}

void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) {
proto::AttrDef key_in;
proto::AttrDef value_in;
auto op_desc_attr = op_def_proto->mutable_attr();
if (!op_desc->input_name_idx_.empty()) {
for (auto &item : op_desc->input_name_idx_) {
key_in.mutable_list()->add_s(item.first);
value_in.mutable_list()->add_i(item.second);
}
op_desc_attr->insert({"_input_name_key", key_in});
op_desc_attr->insert({"_input_name_value", value_in});
}
proto::AttrDef key_out;
proto::AttrDef value_out;
if (!op_desc->output_name_idx_.empty()) {
for (auto &item : op_desc->output_name_idx_) {
key_out.mutable_list()->add_s(item.first);
value_out.mutable_list()->add_i(item.second);
}
op_desc_attr->insert({"_output_name_key", key_out});
op_desc_attr->insert({"_output_name_value", value_out});
}
proto::AttrDef opt_input;
if (!op_desc->optional_input_names_.empty()) {
for (auto &item : op_desc->optional_input_names_) {
opt_input.mutable_list()->add_s(item);
}
op_desc_attr->insert({"_opt_input", opt_input});
}
}

bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) {
if (node == nullptr || op_def_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input Para Node Invalid");
return false;
}
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) {
GELOGE(GRAPH_FAILED, "Serialize OpDesc failed");
return false;
}
if (SerializeEdge(node, op_def_proto)) {
return true;
} else {
return false;
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph,
proto::GraphDef *graph_proto,
bool is_dump) {
if (graph == nullptr || graph_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input para Invalid");
return false;
}
graph_proto->set_name(graph->GetName());
// Inputs
for (const auto &input : graph->GetInputNodes()) {
if (input != nullptr) {
graph_proto->add_input(input->GetName() + ":0");
}
}
// Outputs
for (const auto &output : graph->GetGraphOutNodesInfo()) {
if (output.first != nullptr) {
graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second));
GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second);
}
}
if (graph->attrs_.GetProtoMsg() != nullptr) {
*graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg();
}
for (const auto &node : graph->GetDirectNode()) {
if (!SerializeNode(node, graph_proto->add_op(), is_dump)) {
if (node->GetOpDesc() != nullptr) {
GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str());
}
return false;
}
}
return true;
}

bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) {
if (model_proto == nullptr) {
GELOGE(GRAPH_FAILED, "model_proto para Invalid");
return false;
}
model_proto->set_name(model.GetName());
model_proto->set_custom_version(model.GetPlatformVersion());
model_proto->set_version(model.GetVersion());
if (model.attrs_.GetProtoMsg()) {
*model_proto->mutable_attr() = *model.attrs_.GetProtoMsg();
}
auto &graph = model.graph_;
auto compute_graph = GraphUtils::GetComputeGraph(graph);
if (compute_graph == nullptr) {
GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr");
return false;
}
if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) {
GELOGE(GRAPH_FAILED, "SerializeGraph fail");
return false;
}

for (auto subgraph : compute_graph->GetAllSubgraphs()) {
if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) {
GELOGE(GRAPH_FAILED, "Serialize subgraph failed");
return false;
}
}

return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor(
GeTensorPtr &tensor, proto::TensorDef &tensor_proto) {
tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto));
if (tensor == nullptr) {
GELOGE(GRAPH_FAILED, "tensor is nullptr");
return false;
} else {
return true;
}
}

void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out,
std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out,
std::vector<string> &opt_input) {
if (!key_in.empty()) {
if (key_in.size() != value_in.size()) {
GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(),
value_in.size());
} else {
for (uint32_t i = 0; i < key_in.size(); ++i) {
op_desc->input_name_idx_.insert(std::pair<string, uint32_t>(key_in.at(i), value_in.at(i)));
}
}
}
if (!key_out.empty()) {
if (key_out.size() != value_out.size()) {
GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(),
value_out.size());
} else {
for (uint32_t i = 0; i < key_out.size(); ++i) {
op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key_out.at(i), value_out.at(i)));
}
}
}
if (!opt_input.empty()) {
for (const auto &i : opt_input) {
op_desc->optional_input_names_.insert(i);
}
}
}

bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) {
std::vector<string> opt_input;
std::vector<string> key_in;
std::vector<uint32_t> value_in;
if (op_def_proto.attr().count("_opt_input") > 0) {
auto &name_list = op_def_proto.attr().at("_opt_input").list();
for (const auto &item_s : name_list.s()) {
opt_input.push_back(item_s);
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_opt_input");
}
if (op_def_proto.attr().count("_input_name_key") > 0) {
auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list();
for (const auto &item_s : output_name_key_list.s()) {
key_in.push_back(item_s);
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_input_name_key");
}
if (op_def_proto.attr().count("_input_name_value") > 0) {
auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list();
for (const auto &item_i : input_name_value_list.i()) {
value_in.push_back(static_cast<uint32_t>(item_i));
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_input_name_value");
}
std::vector<string> key_out;
std::vector<uint32_t> value_out;
if (op_def_proto.attr().count("_output_name_key") > 0) {
auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list();
for (const auto &item_s : output_name_key_list.s()) {
key_out.push_back(item_s);
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_output_name_key");
}
if (op_def_proto.attr().count("_output_name_value") > 0) {
auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list();
for (const auto &item_i : output_name_value_list.i()) {
value_out.push_back(static_cast<uint32_t>(item_i));
}
auto op_desc_attr = op_def_proto.mutable_attr();
op_desc_attr->erase("_output_name_value");
}

op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto));
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr.");

// Input tensor
for (auto &input_desc : *op_def_proto.mutable_input_desc()) {
std::shared_ptr<GeTensorDesc> temp_value =
std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc));
GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
op_desc->inputs_desc_.push_back(temp_value);
}
// Output tensor
for (auto &output_desc : *op_def_proto.mutable_output_desc()) {
std::shared_ptr<GeTensorDesc> temp_value =
std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc));
GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
op_desc->outputs_desc_.push_back(temp_value);
}

op_desc->SetId(op_def_proto.id());
uint32_t graph_index = 0;
for (const std::string &name : op_def_proto.subgraph_name()) {
op_desc->AddSubgraphName(name);
op_desc->SetSubgraphInstanceName(graph_index++, name);
}

// insert name index by key and value
AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input);

return true;
}

bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) {
GE_RT_FALSE_CHECK_NOTNULL(graph);
OpDescPtr op_desc = nullptr;
if (!UnserializeOpDesc(op_desc, op_def_proto)) {
GELOGW("UnserializeOpDesc error.");
}

NodePtr node = graph->AddNode(op_desc, op_desc->GetId());
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr.");

// Inputs
int dst_index = 0;
for (const auto &input : op_def_proto.input()) {
string node_name;
int32_t index = 0;
if (ParseNodeIndex(input, node_name, index)) {
node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()});
}
if (index >= 0) {
dst_index++;
}
}
node_map_[op_def_proto.name()] = node;
return true;
}

bool ModelSerializeImp::HandleNodeNameRef() {
// Edges
for (auto &item : node_input_node_names_) {
auto src_node_it = node_map_.find(item.src_node_name);
if (src_node_it == node_map_.end()) {
GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str());
return false;
}
GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue);
if (item.src_out_index >= 0) {
auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index);
auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
if (src_anchor == nullptr || dst_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
item.dst_node_name.c_str(), item.dst_in_index);
return false;
}
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
} else {
// Control edge
auto src_anchor = src_node_it->second->GetOutControlAnchor();
auto dst_anchor = item.dst_node->GetInControlAnchor();
if (src_anchor != nullptr && dst_anchor != nullptr) {
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
}
}
}
// Graph input
for (auto &item : graph_input_node_names_) {
auto node_it = node_map_.find(item.node_name);
if (node_it == node_map_.end()) {
GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
return false;
}
GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
auto ret = item.graph->AddInputNode(node_it->second);
if (ret == nullptr) {
return false;
}
}
// Graph output
for (auto &item : graph_output_node_names_) {
auto node_it = node_map_.find(item.node_name);
if (node_it == node_map_.end()) {
GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
return false;
}

GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index);
GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index);
if (ret == nullptr) {
GELOGE(GRAPH_FAILED, "AddOutputNode failed.");
return false;
}
}
node_input_node_names_.clear();
graph_input_node_names_.clear();
graph_output_node_names_.clear();
node_map_.clear();
return true;
}

bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) {
std::queue<ComputeGraphPtr> all_graphs;
all_graphs.emplace(compute_graph);
while (!all_graphs.empty()) {
ComputeGraphPtr graph = all_graphs.front();
all_graphs.pop();

for (const NodePtr &node : graph->GetDirectNode()) {
const OpDescPtr op_desc = node->GetOpDesc();
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
auto it = subgraphs.find(name);
if (it == subgraphs.end()) {
GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(),
subgraphs.size());
return false;
}

ComputeGraphPtr &subgraph = it->second;
subgraph->SetParentGraph(graph);
subgraph->SetParentNode(node);
compute_graph->AddSubgraph(subgraph->GetName(), subgraph);
all_graphs.emplace(subgraph);
}
}
}

return true;
}

bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) {
model.name_ = model_proto.name();
model.version_ = model_proto.version();
model.platform_version_ = model_proto.custom_version();
model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr());

auto &graphs_proto = *model_proto.mutable_graph();
if (!graphs_proto.empty()) {
auto &graph_proto = graphs_proto[0];
ComputeGraphPtr compute_graph_ptr;
if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) {
model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr);
}

// 0 is main graph, following is subgraph.
map<string, ComputeGraphPtr> subgraphs;
for (int idx = 1; idx < graphs_proto.size(); ++idx) {
ComputeGraphPtr subgraph;
ModelSerializeImp impl;
if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) {
GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed");
return false;
}

if (!impl.HandleNodeNameRef()) {
GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
return false;
}

subgraphs[subgraph->GetName()] = subgraph;
}

if (!RebuildOwnership(compute_graph_ptr, subgraphs)) {
GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed");
return false;
}
}

if (!HandleNodeNameRef()) {
GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
return false;
}
return true;
}

bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) {
graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name());
if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
return false;
}

// Inputs
for (auto input : graph_proto.input()) {
string node_name;
int32_t index;
if (ParseNodeIndex(input, node_name, index)) {
graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
}
}
// Outputs
for (auto output : graph_proto.output()) {
string node_name;
int32_t index;
if (ParseNodeIndex(output, node_name, index)) {
graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
}
}
graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr());
for (auto &op_def_proto : *graph_proto.mutable_op()) {
if (!UnserializeNode(graph, op_def_proto)) {
GELOGE(GRAPH_FAILED, "UnserializeNode fail");
return false;
}
}
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph,
proto::GraphDef &graph_proto) {
if (!UnserializeGraphWithoutEdge(graph, graph_proto)) {
GELOGW("UnserializeGraphWithoutEdge fail");
}
if (!HandleNodeNameRef()) {
GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail");
return false;
}
return true;
}

bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) {
GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null.");
GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null.");

google::protobuf::io::CodedInputStream coded_stream(data, len);
// 2048M -1
coded_stream.SetTotalBytesLimit(INT32_MAX, -1);
if (!proto->ParseFromCodedStream(&coded_stream)) {
GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len);
return false;
}
return true;
}

Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) {
proto::ModelDef model_def;
ModelSerializeImp imp;
if (!imp.SerializeModel(model, &model_def, is_dump)) {
return Buffer();
}
#if !defined(__ANDROID__) && !defined(ANDROID)
Buffer buffer(model_def.ByteSizeLong());
#else
Buffer buffer(model_def.ByteSize());
#endif
GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed");
GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
if (ret != true) {
GELOGW("serialize to array fail.");
}
return buffer;
}

size_t ModelSerialize::GetSerializeModelSize(const Model &model) {
proto::ModelDef model_def;
ModelSerializeImp imp;
if (!imp.SerializeModel(model, &model_def)) {
return 0;
}
#if !defined(__ANDROID__) && !defined(ANDROID)
return model_def.ByteSizeLong();
#else
return model_def.ByteSize();
#endif
}

Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) {
if (data == nullptr) {
GELOGE(GRAPH_FAILED, "data is nullptr");
return Model();
}

std::shared_ptr<proto::ModelDef> model_proto_ptr;
model_proto_ptr = ComGraphMakeShared<proto::ModelDef>();
if (model_proto_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed");
return Model();
}

auto &model_proto = *model_proto_ptr;
if (!ReadProtoFromBinaryFile(data, len, &model_proto)) {
GELOGE(GRAPH_FAILED, "ParseFromArray fail");
return Model();
}

Model model;
ModelSerializeImp imp;
imp.SetProtobufOwner(model_proto_ptr);
if (!imp.UnserializeModel(model, model_proto)) {
GELOGE(GRAPH_FAILED, "Unserialize Model fail");
return Model();
}
return model;
}

Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) {
std::shared_ptr<proto::ModelDef> model_def_ptr = ComGraphMakeShared<proto::ModelDef>(model_def);
GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed");

ModelSerializeImp imp;
imp.SetProtobufOwner(model_def_ptr);
Model model;
if (!imp.UnserializeModel(model, *model_def_ptr)) {
GELOGE(GRAPH_FAILED, "Unserialize Model fail");
return Model();
}
return model;
}

Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) {
proto::GraphDef graph_def;
ModelSerializeImp imp;
if (!imp.SerializeGraph(graph, &graph_def)) {
return Buffer();
}
#if !defined(__ANDROID__) && !defined(ANDROID)
Buffer buffer(graph_def.ByteSizeLong());
#else
Buffer buffer(graph_def.ByteSize());
#endif
GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
if (ret != true) {
GE_LOGE("serialize to array fail.");
}

return buffer;
}

ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) {
if (data == nullptr) {
GELOGE(GRAPH_FAILED, "data is nullptr");
return nullptr;
}

std::shared_ptr<proto::GraphDef> graph_proto_ptr;
graph_proto_ptr = ComGraphMakeShared<proto::GraphDef>();
if (graph_proto_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
return nullptr;
}
proto::GraphDef &graph_proto = *graph_proto_ptr;
if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) {
GELOGE(GRAPH_FAILED, "ParseFromArray fail");
return nullptr;
}

ComputeGraphPtr graph;
ModelSerializeImp imp;
imp.SetProtobufOwner(graph_proto_ptr);
if (!imp.UnserializeGraph(graph, graph_proto)) {
return nullptr;
}
return graph;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) {
proto::OpDef op_def;
ModelSerializeImp imp;
if (!imp.SerializeOpDesc(op_desc, &op_def)) {
return Buffer();
}
#if !defined(__ANDROID__) && !defined(ANDROID)
Buffer buffer(op_def.ByteSizeLong());
#else
Buffer buffer(op_def.ByteSize());
#endif
GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
if (ret != true) {
GE_LOGE("serialize to array fail.");
}

return buffer;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data,
size_t len) {
if (data == nullptr) {
GELOGE(GRAPH_FAILED, "data is nullptr");
return nullptr;
}

std::shared_ptr<proto::OpDef> op_def_ptr;
op_def_ptr = ComGraphMakeShared<proto::OpDef>();
if (op_def_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
return nullptr;
}
proto::OpDef &op_def = *op_def_ptr;
if (!ReadProtoFromBinaryFile(data, len, &op_def)) {
GELOGE(GRAPH_FAILED, "ParseFromArray fail");
return nullptr;
}

OpDescPtr op_desc;
ModelSerializeImp imp;
imp.SetProtobufOwner(op_def_ptr);
if (!imp.UnserializeOpDesc(op_desc, op_def)) {
GELOGW("UnserializeOpDesc error.");
}
return op_desc;
}
} // namespace ge

+ 0
- 3
metadef/graph/module.mk View File

@@ -1,3 +0,0 @@
LOCAL_PATH := $(call my-dir)

include $(LOCAL_PATH)/graph.mk

+ 0
- 877
metadef/graph/node.cc View File

@@ -1,877 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/node.h"
#include <utility>
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "external/graph/operator_factory.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_tensor.h"
#include "graph/operator_factory_impl.h"
#include "graph/shape_refiner.h"
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "common/util/error_manager/error_manager.h"

using std::string;
using std::vector;

namespace ge {
Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph)
: op_(op),
owner_graph_(owner_graph),
in_data_anchors_(),
out_data_anchors_(),
in_control_anchor_(nullptr),
out_control_anchor_(nullptr),
attrs_(),
has_init_(false) {
anchor_status_updated_ = false;
}

Node::~Node() {
for (const auto &in_data_anchor : in_data_anchors_) {
if (in_data_anchor != nullptr) {
in_data_anchor->UnlinkAll();
}
}
for (const auto &out_data_anchor : out_data_anchors_) {
if (out_data_anchor != nullptr) {
out_data_anchor->UnlinkAll();
}
}
if (in_control_anchor_ != nullptr) {
in_control_anchor_->UnlinkAll();
}
if (out_control_anchor_ != nullptr) {
out_control_anchor_->UnlinkAll();
}
}

graphStatus Node::Init() {
if (has_init_) {
return GRAPH_SUCCESS;
}
GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr");
size_t size = op_->GetAllInputsSize();
for (size_t i = 0; i < size; i++) {
std::shared_ptr<InDataAnchor> anchor = ComGraphMakeShared<InDataAnchor>(shared_from_this(), i);
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed.");
return GRAPH_FAILED;
}
in_data_anchors_.push_back(anchor);
}
size = op_->GetOutputsSize();
for (size_t i = 0; i < size; i++) {
std::shared_ptr<OutDataAnchor> anchor = ComGraphMakeShared<OutDataAnchor>(shared_from_this(), i);
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Current out_data_anchor is null, malloc shared_ptr failed.");
return GRAPH_FAILED;
}
out_data_anchors_.push_back(anchor);
}
in_control_anchor_ = ComGraphMakeShared<InControlAnchor>(shared_from_this(), -1);
out_control_anchor_ = ComGraphMakeShared<OutControlAnchor>(shared_from_this(), -1);
if (in_control_anchor_ == nullptr || out_control_anchor_ == nullptr) {
GELOGE(GRAPH_FAILED, "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed.");
return GRAPH_FAILED;
}
has_init_ = true;
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const {
GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr");
return op_->GetName();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const {
GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr");
return op_->GetType();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAttrsAreEqual(const Node &r_node) const {
const auto &attr_map = this->attrs_;
const auto &r_attr_map = r_node.attrs_;
// 1.Verify node's map<string, AttrValue> size
if (attr_map.size() != r_attr_map.size()) {
GELOGE(GRAPH_FAILED, "Size of node's attr map verify failed, node name: %s.", this->GetName().c_str());
return false;
}
// 2.Verify node's map<string, AttrValue> key, verify values is temporarily not implemented
for (const auto &it : attr_map) {
if (r_attr_map.count(it.first) == 0) {
GELOGE(GRAPH_FAILED, "Key of node's attr map verify failed, node name: %s key name: %s.", this->GetName().c_str(),
it.first.c_str());
return false;
}
}
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const {
return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) ||
((this->op_ == nullptr) && (r_node.op_ == nullptr))) &&
IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") &&
IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") &&
IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") &&
IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_"));
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor,
const AnchorPtr &right_anchor,
size_t i) const {
GE_IF_BOOL_EXEC(left_anchor == nullptr, GELOGE(GRAPH_FAILED, "left_anchor is null."); return false);
GE_IF_BOOL_EXEC(right_anchor == nullptr, GELOGE(GRAPH_FAILED, "right_anchor is null."); return false);

const auto anchor_peer_size = left_anchor->GetPeerAnchors().size();
const auto right_anchor_peer_size = right_anchor->GetPeerAnchors().size();
// Firstly, verify anchor's peer anchors size equal or not
if (anchor_peer_size != right_anchor_peer_size) {
GELOGE(GRAPH_FAILED,
"Size of anchor's peer anchors verify failed, node name: %s "
"anchor_peer_size [%zu] is different form [%zu] at index [%zu].",
this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i);
return false;
}
// Secondly, verify anchor's peer anchor owner node equal or not
for (size_t j = 0; j < anchor_peer_size; j++) {
const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
if (peer_node == nullptr || r_peer_node == nullptr) {
GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
this->GetName().c_str(), i, j);
return false;
}
// Determine the connection relationship by linking the node's name
if (peer_node->GetName() != r_peer_node->GetName()) {
GELOGE(GRAPH_FAILED,
"anchor's peer node name verify failed, node name: %s index[%zu]"
"peer node name %s is different from %s at index [%zu].",
this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j);
return false;
}
}
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const {
// 1.Verify all in data and control anchors size
const auto in_data_anchor_size = this->GetAllInDataAnchors().size();
const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size();
if (in_data_anchor_size != r_in_data_anchor_size) {
GELOGE(GRAPH_FAILED, "Size of node's in data anchors verify failed, node name: %s.", this->GetName().c_str());
return false;
}
const auto l_in_anchors = this->GetAllInAnchors();
const auto r_in_anchors = r_node.GetAllInAnchors();
// Data anchors size equal, all anchors size not equal, means control anchor size not equal
const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size;
const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size;
if (in_control_anchor_size != r_in_control_anchor_size) {
GELOGE(GRAPH_FAILED, "Size of node's in control anchors verify failed, node name: %s.", this->GetName().c_str());
return false;
}
// 2.Verify all in data and control anchors connect info
for (size_t i = 0; i < this->GetAllInAnchors().size(); i++) {
// Verify data anchors
if (i < in_data_anchor_size) {
const auto &in_anchor = l_in_anchors.at(i);
const auto &r_in_anchor = r_in_anchors.at(i);
if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) {
GELOGE(GRAPH_FAILED, "Node's in data control anchor verify failed, node name: %s.", this->GetName().c_str());
return false;
}
} else {
// Verify control anchors
const auto &in_control_anchor = l_in_anchors.at(i);
const auto &r_in_control_anchor = r_in_anchors.at(i);
if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) {
GELOGE(GRAPH_FAILED, "Node's in control anchor verify failed, node name: %s.", this->GetName().c_str());
return false;
}
}
}
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const {
// 1.Verify all out data and control anchors size
const auto l_out_data_anchors = this->GetAllOutDataAnchors();
const auto r_out_data_anchors = r_node.GetAllOutDataAnchors();
const auto out_data_anchor_size = l_out_data_anchors.size();
const auto r_out_data_anchor_size = r_out_data_anchors.size();
if (out_data_anchor_size != r_out_data_anchor_size) {
GELOGE(GRAPH_FAILED, "Size of node's out data anchors verify failed, node name: %s.", this->GetName().c_str());
return false;
}
const auto l_out_anchors = this->GetAllOutAnchors();
const auto r_out_anchors = r_node.GetAllOutAnchors();
// Data anchors size equal, all anchors size not equal, means control anchor size not equal
const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size;
const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size;
if (out_control_anchor_size != r_out_control_anchor_size) {
GELOGE(GRAPH_FAILED, "Size of node's out control anchors verify failed, node name: %s.", this->GetName().c_str());
return false;
}

// 2.Verify all out data and control anchors connect info
for (size_t i = 0; i < this->GetAllOutAnchors().size(); i++) {
// Verify data anchors
if (i < out_data_anchor_size) {
const auto &out_anchor = l_out_data_anchors.at(i);
const auto &r_out_anchor = r_out_data_anchors.at(i);
if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) {
GELOGE(GRAPH_FAILED, "Node's out data control anchor verify failed, node name: %s.", this->GetName().c_str());
return false;
}
} else {
// Verify control anchors
const auto &out_control_anchor = l_out_anchors.at(i);
const auto &r_out_control_anchor = r_out_anchors.at(i);
if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) {
GELOGE(GRAPH_FAILED, "Node's out control anchor verify failed, node name: %s.", this->GetName().c_str());
return false;
}
}
}
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const {
return (NodeMembersAreEqual(r_node) && NodeAttrsAreEqual(r_node) && NodeInConnectsAreEqual(r_node) &&
NodeOutConnectsAreEqual(r_node));
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) {
// This function is deprecated, please use other two overloaded functions
GE_CHECK_NOTNULL(input_node);
// Input_node ---> this
auto out_anchors = input_node->GetAllOutDataAnchors();
if (out_anchors.size() != 1) {
GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size());
return GRAPH_PARAM_INVALID;
}
GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr");
auto op_desc = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);

if (op_->AddInputDesc(op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "add input desc failed.");
return GRAPH_FAILED;
}
std::shared_ptr<InDataAnchor> anchor = ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size());
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size());
return GRAPH_FAILED;
}
in_data_anchors_.push_back(anchor);
(void)out_anchors.at(0)->LinkTo(in_data_anchors_.back());

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index,
NodePtr input_node) {
GE_CHECK_NOTNULL(input_node);
// Input_node ---> this
auto out_anchors = input_node->GetAllOutDataAnchors();
if (out_anchors.size() != 1) {
GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size());
return GRAPH_PARAM_INVALID;
}

GE_CHECK_NOTNULL(op_);
auto op_desc = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);

if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "add input desc failed.");
return GRAPH_FAILED;
}

if (index < GetAllInDataAnchors().size()) {
(void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]);
} else {
std::shared_ptr<InDataAnchor> anchor =
ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size());
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size());
return GRAPH_FAILED;
}
in_data_anchors_.push_back(anchor);
(void)out_anchors.at(0)->LinkTo(in_data_anchors_.back());
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const NodePtr &input_node) {
// This function is used for ParseWeights.
GE_CHECK_NOTNULL(input_node);
// Input_node ---> this
auto out_anchors = input_node->GetAllOutDataAnchors();
if (out_anchors.size() != 1) {
GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size());
return GRAPH_PARAM_INVALID;
}

std::shared_ptr<InDataAnchor> anchor = ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size());
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, make anchor failed", out_anchors.size());
return GRAPH_FAILED;
}
in_data_anchors_.push_back(anchor);
(void)out_anchors.at(0)->LinkTo(in_data_anchors_.back());

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const string &name, NodePtr input_node) {
GE_CHECK_NOTNULL(input_node);
// Input_node ---> this
auto out_anchors = input_node->GetAllOutDataAnchors();
if (out_anchors.size() != 1) {
GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size());
return GRAPH_PARAM_INVALID;
}

GE_CHECK_NOTNULL(op_);
auto input_op_desc = input_node->GetOpDesc();
GE_CHECK_NOTNULL(input_op_desc);
auto index = op_->GetInputIndexByName(name);
if (index != -1) {
if (index >= static_cast<int>(in_data_anchors_.size())) {
GELOGE(GRAPH_FAILED, "op %s get input name %s 's index %d is illegal.", op_->GetName().c_str(), name.c_str(),
index);
return GRAPH_FAILED;
}
(void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]);
} else {
std::shared_ptr<InDataAnchor> anchor =
ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size());
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "in_data_anchors_size is:%zu, malloc shared_ptr failed.", in_data_anchors_.size());
return GRAPH_FAILED;
}
in_data_anchors_.push_back(anchor);
(void)out_anchors.at(0)->LinkTo(in_data_anchors_.back());
}
if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "add input desc failed.");
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const {
return owner_graph_.lock();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) {
if (graph == nullptr) {
return GRAPH_PARAM_INVALID;
}
owner_graph_ = graph;
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<InDataAnchorPtr> Node::GetAllInDataAnchors() const {
return Vistor<InDataAnchorPtr>(shared_from_this(), in_data_anchors_);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<OutDataAnchorPtr> Node::GetAllOutDataAnchors() const {
return Vistor<OutDataAnchorPtr>(shared_from_this(), out_data_anchors_);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const {
return in_data_anchors_.size();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const {
return out_data_anchors_.size();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::GetAllInAnchors() const {
std::vector<AnchorPtr> vec;
// Push back in_data_anchors_
for (const auto &in_anchor_iter : Vistor<InDataAnchorPtr>(shared_from_this(), in_data_anchors_)) {
auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_anchor_iter);
if (in_anchor != nullptr) {
vec.push_back(in_anchor);
}
}
// Push back in_control_anchor_
if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) ||
(in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) {
auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_);
if (in_anchor != nullptr) {
vec.push_back(in_anchor);
}
}
return Node::Vistor<AnchorPtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::GetAllOutAnchors() const {
std::vector<AnchorPtr> vec;
// Push back out_data_anchors_
for (const auto &out_anchor_iter : Vistor<OutDataAnchorPtr>(shared_from_this(), out_data_anchors_)) {
auto out_anchor = Anchor::DynamicAnchorCast<Anchor>(out_anchor_iter);
if (out_anchor != nullptr) {
vec.push_back(out_anchor);
}
}
// Push back out_control_anchor_
if (out_control_anchor_->GetPeerInControlAnchors().size() > 0 ||
out_control_anchor_->GetPeerInDataAnchors().size() > 0) {
auto out_anchor = Anchor::DynamicAnchorCast<Anchor>(out_control_anchor_);
if (out_anchor != nullptr) {
vec.push_back(out_anchor);
}
}
return Node::Vistor<AnchorPtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
return in_data_anchors_[idx];
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const {
// Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) {
GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str());
return nullptr;
} else {
// Return control anchor
if (idx == -1) {
auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_);
return in_anchor;
}
// Return data anchor
return in_data_anchors_[idx];
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const {
// Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"},
{
GetName().c_str(),
std::to_string(idx),
"out_anchor",
GetType().c_str(),
});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
// Return control anchor
if (idx == -1) {
auto out_anchor = Anchor::DynamicAnchorCast<Anchor>(out_control_anchor_);
return out_anchor;
}
// Return data anchor
return out_data_anchors_[idx];
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
return nullptr;
} else {
return out_data_anchors_[idx];
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const {
return in_control_anchor_;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const {
return out_control_anchor_;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInNodes() const {
std::vector<NodePtr> vec;
for (const auto &in_anchor : in_data_anchors_) {
GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr");
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
continue;
}
auto node = out_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
if (in_control_anchor_ != nullptr) {
if (in_control_anchor_->IsPeerOutAnchorsEmpty()) {
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors();
for (const auto &out_anchor : peer_out_anchors) {
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr");
auto node = out_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}

auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors();
for (const auto &out_control_anchor : peer_out_control_anchors) {
GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue,
"in_control_anchor_ peer out control anchors is nullptr");
auto node = out_control_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen(
std::unordered_set<Node *> &nodes_seen) const {
for (const auto &in_anchor : in_data_anchors_) {
GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr");
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
continue;
}
auto node = out_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) {
continue;
}
if (nodes_seen.count(node.get()) == 0) {
return false;
}
}

if (in_control_anchor_ != nullptr) {
if (in_control_anchor_->IsPeerOutAnchorsEmpty()) {
return true;
}
auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors();
for (const auto &out_control_anchor : peer_out_control_anchors) {
GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, "out_control_anchor is nullptr");
auto node = out_control_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) {
continue;
}
if (nodes_seen.count(node.get()) == 0) {
return false;
}
}
}

return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInDataNodes() const {
std::vector<NodePtr> vec;
for (const auto &in_anchor : in_data_anchors_) {
GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr");
auto anchor_ptr = in_anchor->GetPeerOutAnchor();
if (anchor_ptr == nullptr) {
continue;
}
auto node = anchor_ptr->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInControlNodes() const {
std::vector<NodePtr> vec;
if (in_control_anchor_ != nullptr) {
for (const auto &in_anchor : in_control_anchor_->GetPeerOutControlAnchors()) {
GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerOutControlAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutNodes() const {
std::vector<NodePtr> vec;
for (const auto &out_anchor : out_data_anchors_) {
GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr");
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_CHK_BOOL_EXEC((peer_in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr");
auto node = peer_in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}
if (out_control_anchor_ != nullptr) {
auto peer_in_control_anchors = out_control_anchor_->GetPeerInControlAnchors();
for (const auto &in_control_anchor : peer_in_control_anchors) {
GE_CHK_BOOL_EXEC(in_control_anchor != nullptr, continue,
"out_control_anchor_ peer in control anchors is nullptr");
auto node = in_control_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInAllNodes() const {
std::vector<NodePtr> vec;
for (const auto &in_node : GetInDataNodes()) {
vec.push_back(in_node);
}
for (const auto &in_control_node : GetInControlNodes()) {
vec.push_back(in_control_node);
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutDataNodes() const {
std::vector<NodePtr> vec;
for (const auto &out_anchor : out_data_anchors_) {
GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr");
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const {
uint32_t out_nums = 0;
for (const auto &out_anchor : out_data_anchors_) {
GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr");
out_nums += out_anchor->GetPeerInDataNodesSize();
}
return out_nums;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutControlNodes() const {
std::vector<NodePtr> vec;

for (const auto &out_anchor : out_data_anchors_) {
GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr");
for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) {
GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInControlAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}

if (out_control_anchor_ != nullptr) {
for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) {
GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}

return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutAllNodes() const {
std::vector<NodePtr> vec;
for (const auto &out_anchor : out_data_anchors_) {
GE_CHK_BOOL_EXEC((out_anchor != nullptr), { continue; }, "out_data_anchors_ is nullptr");
for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_CHK_BOOL_EXEC((in_anchor != nullptr), { continue; }, "GetPeerInDataAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) {
GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}

if (out_control_anchor_ != nullptr) {
for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) {
GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr");
auto node = in_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
}
}
return Node::Vistor<NodePtr>(shared_from_this(), vec);
}

graphStatus Node::InferShapeAndType() const {
Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this());
graphStatus ret = ShapeRefiner::InferShapeAndType(shared_from_this(), op);
return ret;
}

graphStatus Node::InferOriginFormat() const {
Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this());
// Get infer func and execute
GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr");
return op_->CallInferFormatFunc(op);
}
graphStatus Node::Verify() const {
const string data_type = "Data";
const string aipp_data_type = "AippData";
const string const_type = "Const";
const string variable_type = "Variable";
bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag();
GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr");

if (!is_unknown_graph) {
for (const auto &in_anchor_ptr : GetAllInDataAnchors()) {
GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue);
bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type ||
op_->GetType() == const_type || op_->GetType() == variable_type ||
op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0;
if (!valid_anchor) {
ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"},
{GetName(), std::to_string(in_anchor_ptr->GetIdx())});
GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
return GRAPH_FAILED;
}
}
}

string frameworkop_type = "FrameworkOp";
bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph;
if (need_update_name) {
auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str());
} else {
GELOGD("get op from OperatorFactory success. opType: %s", op_->GetType().c_str());
auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
if (temp_op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "temp op desc is null");
return GRAPH_FAILED;
}
if (!op_->UpdateInputName(temp_op_desc->GetAllInputName())) {
GELOGW("Verify UpdateInputName failed");
}
if (!op_->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
GELOGW("Verify UpdateOutputName failed");
}
}
node_op.BreakConnect();
}
GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;);
if (op_->CommonVerify() == GRAPH_SUCCESS) {
Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this());
auto verify_func = op_->GetVerifyFunc();
if (verify_func == nullptr) {
verify_func = OperatorFactoryImpl::GetVerifyFunc(GetType());
}
if (verify_func != nullptr) {
return (graphStatus)verify_func(op_proxy);
}
return GRAPH_SUCCESS;
} else {
GELOGE(GRAPH_FAILED, "%s Verify failed.", op_->GetType().c_str());
return GRAPH_FAILED;
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { return op_; }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) {
GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr");
GE_CHK_BOOL_EXEC(op_desc != nullptr, return GRAPH_PARAM_INVALID, "Param OpDesc is nullptr");
GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID,
"Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(),
op_desc->GetInputsSize());

GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID,
"Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(),
op_desc->GetOutputsSize());
op_ = op_desc;
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<std::pair<NodePtr, OutDataAnchorPtr>>
Node::GetInDataNodesAndAnchors() const {
std::vector<std::pair<NodePtr, OutDataAnchorPtr>> vec;
for (const auto &p : in_data_anchors_) {
if (p == nullptr) {
GELOGW("indata anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str());
continue;
}
auto anchor_ptr = p->GetPeerOutAnchor();
if (anchor_ptr == nullptr) {
continue;
}
auto node = anchor_ptr->GetOwnerNode();
if (node == nullptr) {
GELOGW("src node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str());
continue;
}
vec.push_back(std::make_pair(node, anchor_ptr));
}
return Node::Vistor<std::pair<NodePtr, OutDataAnchorPtr>>(shared_from_this(), vec);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<std::pair<NodePtr, InDataAnchorPtr>>
Node::GetOutDataNodesAndAnchors() const {
std::vector<std::pair<NodePtr, InDataAnchorPtr>> vec;
for (const auto &p : out_data_anchors_) {
if (p == nullptr) {
GELOGW("out data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str());
continue;
}
for (const auto &in_anchor : p->GetPeerInDataAnchors()) {
if (in_anchor == nullptr) {
GELOGW("dst in data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str());
continue;
}
auto node = in_anchor->GetOwnerNode();
if (node == nullptr) {
GELOGW("dst node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str());
continue;
}
vec.push_back(std::make_pair(node, in_anchor));
}
}
return Node::Vistor<std::pair<NodePtr, InDataAnchorPtr>>(shared_from_this(), vec);
}
} // namespace ge

+ 0
- 1370
metadef/graph/op_desc.cc
File diff suppressed because it is too large
View File


+ 0
- 79
metadef/graph/op_imp.cc View File

@@ -1,79 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <stdint.h>
#include <functional>
#include <vector>
#include "debug/ge_log.h"
#include "debug/ge_util.h"

using namespace std;

namespace ge {

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
BroadCastInfer(const function<vector<int64_t>()>& get_in1_shape, const function<vector<int64_t>()>& get_in2_shape,
const function<void(const vector<int64_t>& outShape)>& set_out_shape) {
auto x1_shape = get_in1_shape();
auto x2_shape = get_in2_shape();
vector<int64_t> y_shape;

if (x1_shape.empty()) {
y_shape = x2_shape;
set_out_shape(y_shape);
return GRAPH_SUCCESS;
}
if (x2_shape.empty()) {
y_shape = x1_shape;
set_out_shape(y_shape);
return GRAPH_SUCCESS;
}

int len_diff = static_cast<int>(x1_shape.size() - x2_shape.size());
if (len_diff >= 0) {
for (int i = 0; i < len_diff; i++) {
y_shape.push_back(x1_shape[i]);
}
int x2_shape_size = static_cast<int>(x2_shape.size());
for (int i = 0; i < x2_shape_size; i++) {
bool shapeFlag =
((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1));
if (shapeFlag) {
GE_LOGE("operands could not be broadcast together");
return GRAPH_FAILED;
}
y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i]));
}
} else {
for (int i = 0; i < -len_diff; i++) {
y_shape.push_back(x2_shape[i]);
}
int x1_shape_size = static_cast<int>(x1_shape.size());
for (int i = 0; i < x1_shape_size; i++) {
bool shapeFlag =
((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1));
if (shapeFlag) {
GE_LOGE("operands could not be broadcast together");
return GRAPH_FAILED;
}
y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff]));
}
}
set_out_shape(y_shape);
return GRAPH_SUCCESS;
}

} // namespace ge

+ 0
- 1587
metadef/graph/operator.cc
File diff suppressed because it is too large
View File


+ 0
- 48
metadef/graph/operator_factory.cc View File

@@ -1,48 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/operator_factory_impl.h"
#include "debug/ge_log.h"

namespace ge {
Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) {
return OperatorFactoryImpl::CreateOperator(operator_name, operator_type);
}

graphStatus OperatorFactory::GetOpsTypeList(std::vector<std::string> &all_ops) {
return OperatorFactoryImpl::GetOpsTypeList(all_ops);
}

bool OperatorFactory::IsExistOp(const string &operator_type) { return OperatorFactoryImpl::IsExistOp(operator_type); }

OperatorCreatorRegister::OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator) {
(void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator);
}

InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type,
const InferShapeFunc &infer_shape_func) {
(void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func);
}

InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type,
const InferFormatFunc &infer_format_func) {
(void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func);
}

VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) {
(void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func);
}
} // namespace ge

+ 0
- 149
metadef/graph/operator_factory_impl.cc View File

@@ -1,149 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/operator_factory_impl.h"
#include "debug/ge_log.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
shared_ptr<std::map<string, OpCreator>> OperatorFactoryImpl::operator_creators_;
shared_ptr<std::map<string, InferShapeFunc>> OperatorFactoryImpl::operator_infershape_funcs_;
shared_ptr<std::map<string, InferFormatFunc>> OperatorFactoryImpl::operator_inferformat_funcs_;
shared_ptr<std::map<string, VerifyFunc>> OperatorFactoryImpl::operator_verify_funcs_;

Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) {
if (operator_creators_ == nullptr) {
return Operator();
}
auto it = operator_creators_->find(operator_type);
if (it == operator_creators_->end()) {
GELOGW("no OpProto of [%s] registered", operator_type.c_str());
return Operator();
}
return it->second(operator_name);
}

graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector<std::string> &all_ops) {
all_ops.clear();
if (operator_creators_ != nullptr) {
for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) {
all_ops.emplace_back(it->first);
}
} else {
GELOGE(GRAPH_FAILED, "no operator creators found");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

bool OperatorFactoryImpl::IsExistOp(const string &operator_type) {
if (operator_creators_ == nullptr) {
return false;
}
auto it = operator_creators_->find(operator_type);
if (it == operator_creators_->end()) {
return false;
}
return true;
}

InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) {
if (operator_infershape_funcs_ == nullptr) {
return nullptr;
}
auto it = operator_infershape_funcs_->find(operator_type);
if (it == operator_infershape_funcs_->end()) {
return nullptr;
}
return it->second;
}

InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) {
if (operator_inferformat_funcs_ == nullptr) {
GELOGI("operator_inferformat_funcs_ is null");
return nullptr;
}
auto it = operator_inferformat_funcs_->find(operator_type);
if (it == operator_inferformat_funcs_->end()) {
return nullptr;
}
return it->second;
}

VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) {
if (operator_verify_funcs_ == nullptr) {
return nullptr;
}
auto it = operator_verify_funcs_->find(operator_type);
if (it == operator_verify_funcs_->end()) {
return nullptr;
}
return it->second;
}

graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) {
if (operator_creators_ == nullptr) {
operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>());
}
auto it = operator_creators_->find(operator_type);
if (it != operator_creators_->end()) {
return GRAPH_FAILED;
}
(void)operator_creators_->emplace(operator_type, op_creator);
return GRAPH_SUCCESS;
}

graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type,
InferShapeFunc const infer_shape_func) {
if (operator_infershape_funcs_ == nullptr) {
GELOGI("operator_infershape_funcs_ init");
operator_infershape_funcs_.reset(new (std::nothrow) std::map<string, InferShapeFunc>());
}
auto it = operator_infershape_funcs_->find(operator_type);
if (it != operator_infershape_funcs_->end()) {
return GRAPH_FAILED;
}
(void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func);
return GRAPH_SUCCESS;
}

graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type,
InferFormatFunc const infer_format_func) {
if (operator_inferformat_funcs_ == nullptr) {
GELOGI("operator_inferformat_funcs_ init");
operator_inferformat_funcs_.reset(new (std::nothrow) std::map<string, InferFormatFunc>());
}
auto it = operator_inferformat_funcs_->find(operator_type);
if (it != operator_inferformat_funcs_->end()) {
return GRAPH_FAILED;
}
(void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func);
return GRAPH_SUCCESS;
}

graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) {
if (operator_verify_funcs_ == nullptr) {
GELOGI("operator_verify_funcs_ init");
operator_verify_funcs_.reset(new (std::nothrow) std::map<string, VerifyFunc>());
}
auto it = operator_verify_funcs_->find(operator_type);
if (it != operator_verify_funcs_->end()) {
return GRAPH_FAILED;
}
(void)operator_verify_funcs_->emplace(operator_type, verify_func);
return GRAPH_SUCCESS;
}
} // namespace ge

+ 0
- 187
metadef/graph/opsproto/opsproto_manager.cc View File

@@ -1,187 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/opsproto_manager.h"
#include <cstdlib>
#include <algorithm>
#include <functional>
#include <iostream>
#include <sstream>
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_log.h"

namespace ge {
OpsProtoManager *OpsProtoManager::Instance() {
static OpsProtoManager instance;
return &instance;
}

bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) {
std::lock_guard<std::mutex> lock(mutex_);

if (is_init_) {
GELOGI("OpsProtoManager is already initialized.");
return true;
}

/*lint -e1561*/
auto proto_iter = options.find("ge.opsProtoLibPath");
/*lint +e1561*/
if (proto_iter == options.end()) {
GELOGW("ge.opsProtoLibPath option not set, return.");
return false;
}

pluginPath_ = proto_iter->second;
LoadOpsProtoPluginSo(pluginPath_);

is_init_ = true;

return true;
}

void OpsProtoManager::Finalize() {
std::lock_guard<std::mutex> lock(mutex_);

if (!is_init_) {
GELOGI("OpsProtoManager is not initialized.");
return;
}

for (auto handle : handles_) {
if (handle != nullptr) {
if (dlclose(handle) != 0) {
GELOGW("failed to close handle, message: %s", dlerror());
continue;
}
GELOGI("close opsprotomanager handler success");
} else {
GELOGW("close opsprotomanager handler failure, handler is nullptr");
}
}

is_init_ = false;
}

static std::vector<std::string> Split(const std::string &str, char delim) {
std::vector<std::string> elems;
if (str.empty()) {
elems.emplace_back("");
return elems;
}

std::stringstream ss(str);
std::string item;

while (getline(ss, item, delim)) {
elems.push_back(item);
}

auto str_size = str.size();
if (str_size > 0 && str[str_size - 1] == delim) {
elems.emplace_back("");
}

return elems;
}

static void FindParserSo(const std::string &path, std::vector<std::string> &file_list) {
// Lib plugin path not exist
if (path.empty()) {
GELOGI("realPath is empty");
return;
}
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path.size() >= PATH_MAX, return, "path is invalid");

char resolved_path[PATH_MAX] = {0};

// Nullptr is returned when the path does not exist or there is no permission
// Return absolute path when path is accessible
if (realpath(path.c_str(), resolved_path) == nullptr) {
GELOGW("the path [%s] not exsit.", path.c_str());
return;
}

struct dirent *dent = nullptr;
DIR *dir = opendir(resolved_path);
// Lib plugin path not exist
if (dir == nullptr) {
GELOGW("Open directory %s failed,maybe it is not exit or not a dir", resolved_path);
return;
}

while ((dent = readdir(dir)) != nullptr) {
if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) {
continue;
}
std::string name = dent->d_name;
std::string full_name = path + "/" + name;
const std::string so_suff = ".so";

if (dent->d_type != DT_DIR && name.size() >= so_suff.size() &&
name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) {
file_list.push_back(full_name);
GELOGI("OpsProtoManager Parse full name = %s \n", full_name.c_str());
}
}
if (closedir(dir) != 0) {
GELOGW("close dir fail.");
}
}

static void GetPluginSoFileList(const std::string &path, std::vector<std::string> &file_list) {
// Support multi lib directory with ":" as delimiter
std::vector<std::string> v_path = Split(path, ':');

for (size_t i = 0; i < v_path.size(); ++i) {
FindParserSo(v_path[i], file_list);
GELOGI("OpsProtoManager full name = %s", v_path[i].c_str());
}
}

void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) {
if (path.empty()) {
GELOGE(GRAPH_FAILED, "filePath is invalid. please check your text file %s.", path.c_str());
return;
}
std::vector<std::string> file_list;

// If there is .so file in the lib path
GetPluginSoFileList(path, file_list);

// Not found any .so file in the lib path
if (file_list.empty()) {
GELOGE(GRAPH_FAILED, "OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str());
return;
}
// Warning message
GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted.");

// Load .so file
for (auto elem : file_list) {
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle == nullptr) {
GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror());
continue;
} else {
// Close dl when the program exist, not close here
GELOGI("OpsProtoManager plugin load %s success.", elem.c_str());
handles_.push_back(handle);
}
}
}
} // namespace ge

+ 0
- 104
metadef/graph/option/ge_context.cc View File

@@ -1,104 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "./ge_context.h"
#include "./ge_global_options.h"
#include "./ge_local_context.h"
#include "framework/common/ge_types.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
namespace {
const int64_t kMinTrainingTraceJobId = 256;
const int kDecimal = 10;
const char *kHostExecPlacement = "HOST";
} // namespace
GEContext &GetContext() {
static GEContext ge_context{};
return ge_context;
}

graphStatus GEContext::GetOption(const std::string &key, std::string &option) {
return GetThreadLocalContext().GetOption(key, option);
}

bool GEContext::GetHostExecFlag() {
std::string exec_placement;
if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) {
GELOGW("get option OPTION_EXEC_PLACEMENT failed.");
return false;
}
GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str());
return exec_placement == kHostExecPlacement;
}

std::map<std::string, std::string> &GetMutableGlobalOptions() {
static std::map<std::string, std::string> global_options{};
return global_options;
}

void GEContext::Init() {
string session_id;
(void)GetOption("ge.exec.sessionId", session_id);
try {
session_id_ = static_cast<uint64_t>(std::stoi(session_id.c_str()));
} catch (std::invalid_argument &) {
GELOGW("%s transform to int failed.", session_id.c_str());
} catch (std::out_of_range &) {
GELOGW("%s transform to int failed.", session_id.c_str());
}

string device_id;
(void)GetOption("ge.exec.deviceId", device_id);
try {
device_id_ = static_cast<uint32_t>(std::stoi(device_id.c_str()));
} catch (std::invalid_argument &) {
GELOGW("%s transform to int failed.", device_id.c_str());
} catch (std::out_of_range &) {
GELOGW("%s transform to int failed.", device_id.c_str());
}

string job_id;
(void)GetOption("ge.exec.jobId", job_id);
std::string s_job_id = "";
for (auto c : job_id) {
if (c >= '0' && c <= '9') {
s_job_id += c;
}
}
if (s_job_id == "") {
trace_id_ = kMinTrainingTraceJobId;
return;
}
int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal);
if (d_job_id < kMinTrainingTraceJobId) {
trace_id_ = d_job_id + kMinTrainingTraceJobId;
} else {
trace_id_ = d_job_id;
}
}

uint64_t GEContext::SessionId() { return session_id_; }

uint32_t GEContext::DeviceId() { return device_id_; }

uint64_t GEContext::TraceId() { return trace_id_; }

void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; }

void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; }

} // namespace ge

+ 0
- 60
metadef/graph/option/ge_local_context.cc View File

@@ -1,60 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "./ge_local_context.h"
#include <utility>

namespace ge {
namespace {
thread_local GEThreadLocalContext thread_context;
}

GEThreadLocalContext &GetThreadLocalContext() { return thread_context; }

graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) {
auto graph_iter = graph_options_.find(key);
if (graph_iter != graph_options_.end()) {
option = graph_iter->second;
return GRAPH_SUCCESS;
}
auto session_iter = session_options_.find(key);
if (session_iter != session_options_.end()) {
option = session_iter->second;
return GRAPH_SUCCESS;
}
auto global_iter = global_options_.find(key);
if (global_iter != global_options_.end()) {
option = global_iter->second;
return GRAPH_SUCCESS;
}
return GRAPH_PARAM_INVALID;
}

void GEThreadLocalContext::SetGlobalOption(map<string, string> options_map) {
global_options_.clear();
global_options_ = std::move(options_map);
}

void GEThreadLocalContext::SetSessionOption(map<string, string> options_map) {
session_options_.clear();
session_options_ = std::move(options_map);
}

void GEThreadLocalContext::SetGraphOption(map<std::string, string> options_map) {
graph_options_.clear();
graph_options_ = std::move(options_map);
}
} // namespace ge

+ 0
- 455
metadef/graph/ref_relation.cc View File

@@ -1,455 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/ref_relation.h"

#include <unordered_set>
#include <unordered_map>

#include "utils/mem_utils.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "debug/ge_attr_define.h"
#include "graph/ge_error_codes.h"
#include "graph/utils/graph_utils.h"
#include "framework/common/debug/ge_log.h"

using namespace std;
using namespace ge;
namespace ge {
namespace {
const char *kRefIndex = "_parent_node_index";
const string kWhile = "While";
const string kIf = "If";
const string kCase = "Case";

const uint16_t kMaxElementNum = 100;

std::unordered_set<string> function_op = {kWhile, kIf, kCase};
} // namespace

/* Impl */
class RefRelations::Impl {
public:
graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) {
unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get()));
std::string lookup_key =
key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number);
auto iter = look_up_table_.find(lookup_key);
if (iter != look_up_table_.end()) {
for (auto &c : iter->second) {
result.insert(c);
}
return GRAPH_SUCCESS;
}
GELOGW("can not find any relations! key value is %s", lookup_key.c_str());
return GRAPH_SUCCESS;
};
graphStatus BuildRefRelations(ge::ComputeGraph &root_graph);
graphStatus Clear() {
GELOGD("Start clear boundary reflections between main graph and sub graph!");
look_up_table_.clear();
values_.clear();
return GRAPH_SUCCESS;
};

private:
graphStatus BuildLookUpTables();
graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
vector<vector<RefCell>> &node_refs);
graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
vector<vector<RefCell>> &node_refs);
graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node,
const vector<vector<NodePtr>> &classed_data_nodes,
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes,
vector<vector<RefCell>> &node_refs);
void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes,
vector<NodePtr> &netoutput_nodes, const std::vector<std::string> &sub_graph_names,
const std::string &node_type);

graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph);
graphStatus ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, vector<vector<NodePtr>> &classed_data_nodes);
graphStatus ProcessSubgraphNetoutput(const vector<NodePtr> &netoutput_nodes,
vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes);

std::unordered_map<string, vector<RefCell>> look_up_table_;
std::vector<vector<vector<RefCell>>> values_;
};

// Node Level
graphStatus RefRelations::Impl::BuildRefRelationsForBranch(
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) {
GELOGD("Enter BuildRefRelationsForBranch!");

size_t ref_i = 0;
for (const auto &ref_i_data_nodes : classed_data_nodes) {
vector<RefCell> in_ref_i_all_refs;
RefCell cell_root;
cell_root.node_name = root_node->GetName();
cell_root.node = root_node;
cell_root.in_out = NODE_IN;
cell_root.in_out_idx = ref_i;
in_ref_i_all_refs.emplace_back(cell_root);
for (const auto &data : ref_i_data_nodes) {
RefCell cell_in;
RefCell cell_out;
cell_in.node_name = data->GetName();
cell_in.node = data;
cell_in.in_out = NODE_IN;
cell_in.in_out_idx = 0;
cell_out.node_name = data->GetName();
cell_out.node = data;
cell_out.in_out = NODE_OUT;
cell_out.in_out_idx = 0;
in_ref_i_all_refs.emplace_back(cell_in);
in_ref_i_all_refs.emplace_back(cell_out);
}
node_refs.emplace_back(in_ref_i_all_refs);
ref_i++;
}

size_t ref_o = 0;
for (const auto &ref_o_net_nodes : classed_netoutput_nodes) {
vector<RefCell> out_ref_i_all_refs;
RefCell cell_root;
cell_root.node_name = root_node->GetName();
cell_root.node = root_node;
cell_root.in_out = NODE_OUT;
cell_root.in_out_idx = ref_o;
out_ref_i_all_refs.emplace_back(cell_root);
for (const auto &ele : ref_o_net_nodes) {
RefCell cell_netoutput_in;
cell_netoutput_in.node_name = (ele.first)->GetName();
cell_netoutput_in.node = ele.first;
cell_netoutput_in.in_out = NODE_IN;
cell_netoutput_in.in_out_idx = ele.second;
out_ref_i_all_refs.emplace_back(cell_netoutput_in);
}
node_refs.emplace_back(out_ref_i_all_refs);
ref_o++;
}
return GRAPH_SUCCESS;
}

graphStatus RefRelations::Impl::BuildLookUpTables() {
GELOGD("start to build look up table!");
for (size_t i = 0; i < values_.size(); i++) {
vector<vector<RefCell>> &val = values_[i];
for (const auto &ele : val) {
for (const auto &ref_cell : ele) {
string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) +
std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get())));
look_up_table_[key] = ele;
}
}
}
return GRAPH_SUCCESS;
}

graphStatus RefRelations::Impl::BuildRefRelationsForWhile(
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) {
GELOGD("Enter BuildRefRelations for while op!");
// data_nodes has been sorted
// for while, input num must be same as output num
auto input_num = root_node->GetAllInDataAnchorsSize();
NodePtr netoutput = nullptr;

size_t ref_i = 0;
while (ref_i < input_num) {
auto &ref_i_data_nodes = classed_data_nodes[ref_i];
auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i];

vector<RefCell> ref_i_all_refs;
RefCell cell_root_i;
RefCell cell_root_o;
cell_root_i.node_name = root_node->GetName();
cell_root_i.node = root_node;
cell_root_i.in_out = NODE_IN;
cell_root_i.in_out_idx = ref_i;
ref_i_all_refs.emplace_back(cell_root_i);
cell_root_o.node_name = root_node->GetName();
cell_root_o.node = root_node;
cell_root_o.in_out = NODE_OUT;
cell_root_o.in_out_idx = ref_i;
ref_i_all_refs.emplace_back(cell_root_o);
for (const auto &data : ref_i_data_nodes) {
RefCell cell_in;
RefCell cell_out;
cell_in.node_name = data->GetName();
cell_in.node = data;
cell_in.in_out = NODE_IN;
cell_in.in_out_idx = 0;
cell_out.node_name = data->GetName();
cell_out.node = data;
cell_out.in_out = NODE_OUT;
cell_out.in_out_idx = 0;
ref_i_all_refs.emplace_back(cell_in);
ref_i_all_refs.emplace_back(cell_out);
}

for (const auto &ele : ref_i_net_nodes) {
RefCell cell_netoutput_in;
RefCell cell_netoutput_out;
cell_netoutput_in.node_name = (ele.first)->GetName();
cell_netoutput_in.node = ele.first;
cell_netoutput_in.in_out = NODE_IN;
cell_netoutput_in.in_out_idx = ele.second;
ref_i_all_refs.emplace_back(cell_netoutput_in);
netoutput = ele.first;
}
node_refs.emplace_back(ref_i_all_refs);
ref_i++;
}
/* There exist scene like the follows, it means data0 data1 netoutput 0'th
* and 1'th tensor should be the same addr.
* Data0 Data1
* \/
* /\
* netoutput
*/
if (netoutput == nullptr) {
return GRAPH_SUCCESS;
}
for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) {
auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor == nullptr) {
continue;
}
auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str());
continue;
}
if (peer_out_data_node->GetType() != DATA) {
continue;
}
auto in_data_anchor_idx = in_anchor->GetIdx();
auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx));
int ref_d = 0;
int ref_n = 0;
(void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d);
(void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n);

node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end());
node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end());
}

return GRAPH_SUCCESS;
}
// build ref relations according to diff func op type
graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType(
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes,
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) {
// data_nodes has been sorted
auto node_type = root_node->GetType();

auto status = GRAPH_SUCCESS;
if (node_type != kWhile) {
status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
} else {
status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs);
}
return status;
}

void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes,
vector<NodePtr> &netoutput_nodes,
const std::vector<std::string> &sub_graph_names,
const std::string &node_type) {
int sub_graph_idx = 0;
for (const auto &name : sub_graph_names) {
auto sub_graph = root_graph.GetSubgraph(name);
if (sub_graph == nullptr) {
GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str());
continue;
}
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
auto sub_graph_node_type = sub_graph_node->GetType();

if (sub_graph_node_type == DATA) {
data_nodes.emplace_back(sub_graph_node);
} else if (sub_graph_node_type == NETOUTPUT) {
// if while, the first subgraph must be cond subgraph.
// There is no meaning for refs ,so continue
if (node_type == kWhile && sub_graph_idx == 0) {
continue;
}
netoutput_nodes.emplace_back(sub_graph_node);
}
continue;
}
sub_graph_idx++;
}
}

graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) {
auto parent_graph_ptr = graph.GetParentGraph();
if (parent_graph_ptr == nullptr) {
root_graph = graph;
return GRAPH_SUCCESS;
}
auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr);
if (root_graph_ptr == nullptr) {
GE_LOGE("Get null root graph");
return GRAPH_PARAM_INVALID;
}
root_graph = *root_graph_ptr;
return GRAPH_SUCCESS;
}

graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes,
vector<vector<NodePtr>> &classed_data_nodes) {
GELOGD("start to process subgraph data nodes!");
int max_ref_idx = 0;
for (const auto &e : data_nodes) {
int i;
bool is_exist = true;
is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i);
if (!is_exist) {
GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex);
return GRAPH_FAILED;
}
max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx;
}

while (!data_nodes.empty()) {
auto data = data_nodes.back();
data_nodes.pop_back();
int ref_idx = 0;
(void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx);
if (ref_idx >= static_cast<int>(classed_data_nodes.size())) {
return GRAPH_FAILED;
}
classed_data_nodes[ref_idx].emplace_back(data);
}
return GRAPH_SUCCESS;
}

graphStatus RefRelations::Impl::ProcessSubgraphNetoutput(
const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) {
GELOGD("[RefRelations]Start to process subgraph netoutput!");
for (const auto &sub_netoutput_node : netoutput_nodes) {
auto op_desc = sub_netoutput_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);

for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) {
auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx());
if (in_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it",
sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx());
return GRAPH_FAILED;
}
int ref_o;
if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) {
if (ref_o >= static_cast<int>(classed_netoutput_nodes.size())) {
return GRAPH_FAILED;
}
classed_netoutput_nodes[ref_o].emplace_back(
std::pair<NodePtr, size_t>({sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())}));
}
}
}
return GRAPH_SUCCESS;
}

graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) {
GELOGD("Start to build ref relations!");
/* First Step: Get root graph */
ge::ComputeGraph &root_graph = graph;
auto status = GetRootGraph(graph, root_graph);
if (status != GRAPH_SUCCESS) {
return status;
}

for (const auto &node : graph.GetAllNodes()) {
auto node_type = node->GetType();
std::vector<NodePtr> ref_nodes;
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
continue;
}
vector<NodePtr> data_nodes;
vector<NodePtr> netoutput_nodes;
// Get data and netoutput of sub_graph
GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type);
size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum;
vector<vector<NodePtr>> classed_data_nodes(max_elem_num); // according to ref_idx
vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(max_elem_num); // according to ref_idx
status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "classfy data nodes failed!");
return status;
}

// for netoutput
// check netoutput
// here main graph output number must be the same as every sub_graph netoutput node
// key: netoutput node_ptr ,<ref_idx, net_in_idx>
status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "process netoutput failed!");
return status;
}

vector<vector<RefCell>> node_refs;
status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs);
if (status != GRAPH_SUCCESS) {
GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str());
return status;
}
if (!node_refs.empty()) {
values_.push_back(node_refs);
}
}
/* Seconde Step: generate map */
status = BuildLookUpTables();
if (status != GRAPH_SUCCESS) {
GELOGE(status, "Build look up tables failed!");
return status;
}
return GRAPH_SUCCESS;
}

/* Ref Relations Interface */
RefRelations::RefRelations() {
impl_ = MakeShared<Impl>();
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "MakeShared failed!");
return;
}
}

graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) {
GE_CHECK_NOTNULL(impl_);
return impl_->LookUpRefRelations(key, result);
}

graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) {
GE_CHECK_NOTNULL(impl_);
return impl_->BuildRefRelations(root_graph);
}

graphStatus RefRelations::Clear() {
GE_CHECK_NOTNULL(impl_);
return impl_->Clear();
}
} // namespace ge

+ 0
- 96
metadef/graph/runtime_inference_context.cc View File

@@ -1,96 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/runtime_inference_context.h"
#include <cstdint>
#include "framework/common/debug/ge_log.h"

namespace ge {
std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> RuntimeInferenceContext::contexts_;
std::mutex RuntimeInferenceContext::ctx_mu_;

graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) {
GELOGI("To create context. session id = %s", context_id.c_str());
auto ctx = std::unique_ptr<RuntimeInferenceContext>(new (std::nothrow) RuntimeInferenceContext());
if (ctx == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. context_id = %s", context_id.c_str());
return GRAPH_FAILED;
}

std::lock_guard<std::mutex> lk(ctx_mu_);
auto emplace_ret = contexts_.emplace(context_id, std::move(ctx));
if (!emplace_ret.second) {
GELOGE(GRAPH_FAILED, "Old context not destroyed");
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

void RuntimeInferenceContext::DestroyContext(const std::string &context_id) {
GELOGI("To destroy context. session id = %s", context_id.c_str());
std::lock_guard<std::mutex> lk(ctx_mu_);
contexts_.erase(context_id);
}

graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) {
std::lock_guard<std::mutex> lk(ctx_mu_);
auto it = contexts_.find(context_id);
if (it != contexts_.end()) {
*ctx = it->second.get();
return GRAPH_SUCCESS;
}

GELOGD("Runtime inference context not created. session id = %s", context_id.c_str());
return GRAPH_FAILED;
}

graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) {
std::lock_guard<std::mutex> lk(mu_);
auto &output_tensors = tensors_[node_id];
if (static_cast<uint32_t>(output_id) >= output_tensors.size()) {
output_tensors.resize(output_id + 1);
}

GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id);
output_tensors[output_id] = std::move(tensor);
return GRAPH_SUCCESS;
}

graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) {
if (output_id < 0) {
GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id);
return GRAPH_PARAM_INVALID;
}

std::lock_guard<std::mutex> lk(mu_);
auto iter = tensors_.find(node_id);
if (iter == tensors_.end()) {
GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id);
return INTERNAL_ERROR;
}

auto &output_tensors = iter->second;
if (static_cast<uint32_t>(output_id) >= output_tensors.size()) {
GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id);
return GRAPH_FAILED;
}

GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id);
tensor = output_tensors[output_id];
return GRAPH_SUCCESS;
}
} // namespace ge

+ 0
- 688
metadef/graph/shape_refiner.cc View File

@@ -1,688 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/shape_refiner.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"

#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "external/graph/operator.h"
#include "external/graph/operator_factory.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

namespace ge {
namespace {
const uint32_t kWhileBodySubGraphIdx = 1;

graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
GELOGD("Enter reverse brush while body subgraph process!");

auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
if (sub_graph_body == nullptr) {
GELOGE(GRAPH_FAILED, "Get while body graph failed!");
return GRAPH_FAILED;
}

for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i);
GE_IF_BOOL_EXEC(input_desc == nullptr,
GELOGW("Get null input by index %zu from node %s ", i, node_sub->GetName().c_str());
continue);
(void)input_desc->SetUnknownDimNumShape();
}
for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i);
(void)output_desc->SetUnknownDimNumShape();
}
}

return GRAPH_SUCCESS;
}

graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
// check sub_graph shape. Get max for update.
for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
if (ref_out_tensors[i].empty()) {
continue;
}

int64_t max_size = 0;
size_t max_shape_index = 0;
auto &ref_out_tensor = ref_out_tensors[i].at(0);
const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape();
for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
auto &tensor = ref_out_tensors[i].at(j);
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
return GRAPH_FAILED;
}

auto shape = tensor.MutableShape();
if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu",
node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
return GRAPH_FAILED;
}

int64_t size = 1;
for (auto dim : shape.GetDims()) {
if (INT64_MAX / dim < size) {
GELOGE(PARAM_INVALID, "The shape size overflow");
return PARAM_INVALID;
}
size *= dim;
}

if (size > max_size) {
max_size = size;
max_shape_index = j;
}
}

(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
}

return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
GELOGD("Enter update parent node shape for class branch op process");
if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
return UpdataOutputForMultiBatcch(node, ref_out_tensors);
}

// check sub_graph shape.If not same ,do unknown shape process
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
for (auto &tensor : ref_out_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
break;
}
for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
continue;
}
GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
(void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
}
}
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
}
return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
GELOGD("Enter update parent node shape for class while op process");
if (ref_data_tensors.size() != ref_out_tensors.size()) {
GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(),
ref_data_tensors.size(), ref_out_tensors.size());
return GRAPH_FAILED;
}
for (size_t i = 0; i < ref_data_tensors.size(); i++) {
if (ref_out_tensors[i].size() != 1) {
GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!");
return GRAPH_FAILED;
}
}
bool is_need_reverse_brush = false;
// check input and output
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
auto tmp_shape = ref_out_tensor.MutableShape();
// ref_i's data and output tensor shape should be same
for (auto &tensor : ref_data_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims() != tmp_shape.GetDims()) {
ref_out_tensor.SetUnknownDimNumShape();
is_need_reverse_brush = true;
break;
}
}
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
}
// reverse refresh while body shape
if (is_need_reverse_brush) {
return ReverseBrushWhileBodySubGraph(node);
}
return GRAPH_SUCCESS;
}

graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return GRAPH_SUCCESS;
}

auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
for (const auto &name : sub_graph_names) {
if (name.empty()) {
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
continue;
}
auto sub_graph = root_graph->GetSubgraph(name);
if (sub_graph == nullptr) {
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
return GRAPH_FAILED;
}
for (const auto &node_sub : sub_graph->GetDirectNode()) {
if (node_sub->GetType() != DATA) {
continue;
}
int ref_i;
auto data_opdesc = node_sub->GetOpDesc();
if (data_opdesc == nullptr) {
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
continue;
}
auto input_desc = op_desc->MutableInputDesc(ref_i);
if (input_desc == nullptr) {
GE_LOGE(
"The ref index(%d) on the data %s on the sub graph %s "
"parent node %s are incompatible, inputs num %u",
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
return GRAPH_FAILED;
}
GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
node->GetName().c_str());
auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);

if (ret != GRAPH_SUCCESS) {
GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
return ret;
}
ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
if (ret != GRAPH_SUCCESS) {
GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
return ret;
}
}
}
return GRAPH_SUCCESS;
}

graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
}
if (sub_node->GetType() == DATA) {
if (sub_node->GetOpDesc() == nullptr) {
return GRAPH_FAILED;
}

int ref_i;
if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
return GRAPH_FAILED;
}
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(),
ref_i, node->GetAllInDataAnchorsSize());
return GRAPH_FAILED;
}
ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
}
}
return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return GRAPH_SUCCESS;
}

std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());

for (const auto &name : sub_graph_names) {
if (name.empty()) {
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
continue;
}
auto sub_graph = root_graph->GetSubgraph(name);
if (sub_graph == nullptr) {
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
return GRAPH_FAILED;
}
NodePtr netoutput = nullptr;
auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
if (ret != GRAPH_SUCCESS) {
return ret;
}
if (netoutput == nullptr) {
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
return GRAPH_FAILED;
}
auto netoutput_opdesc = netoutput->GetOpDesc();
if (netoutput_opdesc == nullptr) {
GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
if (edge_desc == nullptr) {
GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(),
node->GetName().c_str(), edge_anchor->GetIdx());
return GRAPH_FAILED;
}
GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(),
edge_desc->GetShape().GetDimNum());
int ref_i;
if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
// if there is no ref index on the TensorDesc, it means the output data will be ignored outer.
continue;
}
GELOGI("Parent node index of edge desc is %d", ref_i);
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
return GRAPH_FAILED;
}
ref_out_tensors[ref_i].emplace_back(*edge_desc);
}
}

if (node->GetType() == WHILE) {
return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors);
}
return UpdateParentNodeForBranch(node, ref_out_tensors);
}

string Serial(const vector<int64_t> &dims) {
string serial_string;
serial_string += "[";
for (int64_t dim : dims) {
serial_string += std::to_string(dim) + " ";
}
serial_string += "]";
return serial_string;
}

graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);
for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
auto in_idx = in_anchor->GetIdx();
auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor == nullptr) {
continue;
}
auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
continue;
}
int peer_out_idx = peer_out_data_anchor->GetIdx();
auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));

// check shape and dtype continuity. do not stop process
auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
if (in_desc == nullptr) {
continue;
}
auto in_shape = in_desc->GetShape().GetDims();
auto in_dtype = in_desc->GetDataType();
auto peer_out_shape = peer_out_desc->GetShape().GetDims();
auto peer_out_dtype = peer_out_desc->GetDataType();
if (peer_out_dtype != in_dtype) {
GELOGW(
"current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th "
"output_dtype is [%s].The two dtype should be same! Please check graph and fix it",
node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
} else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
string in_shape_str = Serial(in_shape);
string peer_out_shape_str = Serial(peer_out_shape);
GELOGW(
"current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
"input_shape is [%s].The two shape should be same! Please check graph and fix it",
node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx,
peer_out_shape_str.c_str());
}
// refresh current node input desc
in_desc->SetOriginShape(peer_out_desc->GetOriginShape());
in_desc->SetShape(peer_out_desc->GetShape());
in_desc->SetDataType(peer_out_desc->GetDataType());
in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType());
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)peer_out_desc->GetShapeRange(shape_range);
in_desc->SetShapeRange(shape_range);
ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->GetShape().GetDims().size()));
}
return GRAPH_SUCCESS;
}
} // namespace
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "node is null");
return;
}
ge::OpDescPtr op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );
std::string str;
if (op_desc->GetInputsSize() != 0) {
std::string input_desc_str = "input shape: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
input_desc_str += "[";
for (int64_t dim : input_desc->GetShape().GetDims()) {
input_desc_str += std::to_string(dim) + " ";
}
input_desc_str += "]";
input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" +
TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " ";
}
str += input_desc_str;

input_desc_str = "input origin shape: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
input_desc_str += "[";
for (int64_t dim : input_desc->GetOriginShape().GetDims()) {
input_desc_str += std::to_string(dim) + " ";
}
input_desc_str += "]";
input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " ";
}
str += input_desc_str;
}

if (op_desc->GetAllOutputsDescSize() != 0) {
std::string output_desc_str = "output shape: ";
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
if (output_desc == nullptr) {
continue;
}
output_desc_str += "[";
for (int64_t dim : output_desc->GetShape().GetDims()) {
output_desc_str += std::to_string(dim) + " ";
}
output_desc_str += "]";
output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" +
TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " ";
}
str += output_desc_str;

output_desc_str = "output origin shape: ";
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
if (output_desc == nullptr) {
continue;
}
output_desc_str += "[";
for (int64_t dim : output_desc->GetOriginShape().GetDims()) {
output_desc_str += std::to_string(dim) + " ";
}
output_desc_str += "]";
output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " ";
}
str += output_desc_str;
}
GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str());
}

graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) {
return InferShapeAndType(node, op, true);
}
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
auto op_desc = node->GetOpDesc();
const auto &op_type = op_desc->GetType();

graphStatus ret;
if (before_subgraph) {
ret = UpdateSubGraphDataNodes(node);
if (ret != GRAPH_SUCCESS) {
return ret;
}
}
// Get infer func and execute
ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
// Op ir no infer func, try to get infer func from operator factory
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
return ret;
}

GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
node_op.BreakConnect();
if (temp_op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "temp op desc is null");
return GRAPH_FAILED;
}
if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
GELOGW("InferShapeAndType UpdateInputName failed");
for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
break;
}
return GRAPH_SUCCESS;
}
}
if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
GELOGW("InferShapeAndType UpdateOutputName failed");
}
op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
ret = op_desc->CallInferFunc(op);
GELOGI("op CallInferFunc second. ret: %u", ret);
}
if (ret != GRAPH_SUCCESS) {
return ret;
}

if (!before_subgraph) {
return UpdateParentNodeOutTensor(node);
}
return GRAPH_SUCCESS;
}

InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
const NodePtr &node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "node is null");
return nullptr;
}
InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
if (inference_context == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext");
return nullptr;
}

auto all_in_data_anchors = node->GetAllInDataAnchors();
std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
std::vector<std::string> marks;

bool has_input_shapes_and_types = false;
for (const auto &in_anchor : all_in_data_anchors) {
const auto &out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
continue;
}

auto input_node = out_anchor->GetOwnerNode();
if (input_node == nullptr) {
continue;
}

auto iter = context_map.find(input_node);
if (iter != context_map.end()) {
const auto &src_context = iter->second;
GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr);
GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
input_node->GetName().c_str());
for (auto mark : src_context->GetMarks()) {
marks.push_back(mark);
}
auto output_idx = out_anchor->GetIdx();
auto input_idx = in_anchor->GetIdx();
auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
if (output_idx < static_cast<int>(output_shape_and_type.size())) {
GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
node->GetName().c_str(), input_idx);
input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
has_input_shapes_and_types = true;
} else {
GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
output_shape_and_type.size());
}
}
}

if (has_input_shapes_and_types) {
inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
}
inference_context->SetMarks(marks);

return inference_context;
}

namespace {
thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
return InferShapeAndType(node, true);
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node,
bool before_subgraph) {
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
auto opdesc = node->GetOpDesc();
GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);
// some op can not infershape twice such as aipp
bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
if (need_update_input) {
auto status = UpdateOpInputDesc(node);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "update op input_desc failed!");
return status;
}
}

if (node->Verify() != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
PrintInOutTensorShape(node, "before_infershape");
Operator op = OpDescUtils::CreateOperatorFromNode(node);

if (!is_unknown_graph) {
auto inference_context = CreateInferenceContext(context_map, node);
if (inference_context == nullptr) {
GELOGE(GRAPH_FAILED, "inference context is null");
return GRAPH_FAILED;
}
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
op.SetInferenceContext(inference_context);
}

graphStatus status = InferShapeAndType(node, op, before_subgraph);
if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
if (is_unknown_graph) {
PrintInOutTensorShape(node, "after_infershape when running");
return GRAPH_SUCCESS;
}
auto op_desc = node->GetOpDesc();
for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
output_tensor->SetOriginShape(output_tensor->GetShape());
output_tensor->SetOriginDataType(output_tensor->GetDataType());

GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
}
} else {
GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
return GRAPH_FAILED;
}
if (!is_unknown_graph) {
auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
ctx_after_infer->GetMarks().size());
(void)context_map.emplace(node, ctx_after_infer);
}
}
}
PrintInOutTensorShape(node, "after_infershape");

return GRAPH_SUCCESS;
}
} // namespace ge

+ 0
- 704
metadef/graph/tensor.cc View File

@@ -1,704 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "external/graph/tensor.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_tensor.h"
#include "securec.h"
#include "utils/attr_utils.h"
#include "utils/tensor_adapter.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

namespace {
/// Extra 8 bytes store pointer of string
/// Extra 1 byte store '\0'
const int EXTRA_STORE_POINTER_FOR_STRING = 8;
const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9;
const int64_t UNKNOWN_DIM_SIZE = -1;
} // namespace

namespace ge {
// If not overflow return true
static bool Int64MulNotOverflow(int64_t a, int64_t b) {
if (a > 0) {
if (b > 0) {
if (a > (INT64_MAX / b)) {
return false;
}
} else {
if (b < (INT64_MIN / a)) {
return false;
}
}
} else {
if (b > 0) {
if (a < (INT64_MIN / b)) {
return false;
}
} else {
if ((a != 0) && (b < (INT64_MAX / a))) {
return false;
}
}
}
return true;
}

class TensorDescImpl {
public:
TensorDescImpl() = default;
~TensorDescImpl() = default;
TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {}

Shape shape_;
std::vector<std::pair<int64_t, int64_t>> range_;
Format format_ = FORMAT_ND;
Format origin_format_ = FORMAT_ND;
DataType data_type_ = DT_FLOAT;
Shape origin_shape_;
int64_t size_ = 0;
int64_t real_dim_cnt_ = 0;
std::string name_;
};

class TensorImpl {
public:
TensorImpl() = default;
~TensorImpl() = default;

explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {}
TensorImpl(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data)
: ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {}
TensorImpl(const TensorDesc &tensor_desc, const uint8_t *data, size_t size)
: ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {}
TensorImpl(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data)
: ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {}

GeTensor ge_tensor;
};

class ShapeImpl {
public:
ShapeImpl() = default;
~ShapeImpl() = default;
explicit ShapeImpl(const std::vector<int64_t> &dims) {
bool is_unknown_dim_num = false;
for (const auto &dim : dims) {
if (dim == UNKNOWN_DIM_NUM) {
is_unknown_dim_num = true;
break;
}
}
dims_ = is_unknown_dim_num ? std::vector<int64_t>({UNKNOWN_DIM_NUM}) : dims;
}

std::vector<int64_t> dims_;
};

Shape::Shape() { impl_ = ComGraphMakeShared<ShapeImpl>(); }

Shape::Shape(const std::vector<int64_t> &dims) { impl_ = ComGraphMakeShared<ShapeImpl>(dims); }

size_t Shape::GetDimNum() const {
if (impl_ != nullptr) {
for (auto i : impl_->dims_) {
if (i == UNKNOWN_DIM_NUM) {
return 0;
}
}
return impl_->dims_.size();
}
return 0;
}

int64_t Shape::GetDim(size_t idx) const {
if (impl_ != nullptr) {
if (idx >= impl_->dims_.size()) {
return 0;
}
return impl_->dims_[idx];
}
return 0;
}

graphStatus Shape::SetDim(size_t idx, int64_t value) {
if (impl_ != nullptr) {
if (idx >= impl_->dims_.size()) {
return GRAPH_FAILED;
}
impl_->dims_[idx] = value;
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

std::vector<int64_t> Shape::GetDims() const {
vector<int64_t> dims;
if (impl_ != nullptr) {
return impl_->dims_;
}
return dims;
}

int64_t Shape::GetShapeSize() const {
if (impl_ != nullptr) {
if (impl_->dims_.empty()) {
return 0;
}
int64_t size = 1;
for (auto i : impl_->dims_) {
if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) {
return UNKNOWN_DIM_SIZE;
}

if (!Int64MulNotOverflow(size, i)) {
GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i);
size = 0;
return size;
}
size *= i;
}
return size;
}
return 0;
}

TensorDesc::TensorDesc() {
impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665
}

TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) {
impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); // lint !e665
SetRealDimCnt(shape.GetDimNum());
}

TensorDesc::TensorDesc(const TensorDesc &desc) {
// Copy
impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665
if (desc.impl != nullptr && impl != nullptr) {
*impl = *desc.impl;
}
}

TensorDesc::TensorDesc(TensorDesc &&desc) {
// Move
impl = std::move(desc.impl);
}

TensorDesc &TensorDesc::operator=(const TensorDesc &desc) {
// Copy
if (&desc != this) {
impl = ComGraphMakeShared<TensorDescImpl>();
if (desc.impl != nullptr && impl != nullptr) {
*impl = *desc.impl;
}
}
return *this;
}

TensorDesc &TensorDesc::operator=(TensorDesc &&desc) {
if (&desc != this) {
impl = std::move(desc.impl);
}
return *this;
}

void TensorDesc::Update(const Shape &shape, Format format, DataType dt) {
if (impl != nullptr) {
impl->shape_ = shape;
impl->format_ = format;
impl->data_type_ = dt;
}
}

Shape TensorDesc::GetShape() const {
if (impl != nullptr) {
return impl->shape_;
}
return Shape();
}

void TensorDesc::SetShape(const Shape &shape) {
if (impl != nullptr) {
impl->shape_ = shape;
}
}

// set shape with -2, it stand for unknown shape
graphStatus TensorDesc::SetUnknownDimNumShape() {
if (impl != nullptr) {
impl->shape_ = Shape({UNKNOWN_DIM_NUM});
return GRAPH_SUCCESS;
}
GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!");
return GRAPH_FAILED;
}

// for unknown shape
graphStatus TensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) {
if (impl != nullptr) {
impl->range_ = range;
return GRAPH_SUCCESS;
}
GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!");
return GRAPH_FAILED;
}
graphStatus TensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const {
if (impl != nullptr) {
range = impl->range_;
return GRAPH_SUCCESS;
}
GELOGE(GRAPH_FAILED, "impl is nullptr!");
return GRAPH_FAILED;
}

Shape TensorDesc::GetOriginShape() const {
if (impl != nullptr) {
return impl->origin_shape_;
}
return Shape();
}

void TensorDesc::SetOriginShape(const Shape &origin_shape) {
if (impl != nullptr) {
impl->origin_shape_ = origin_shape;
}
}

Format TensorDesc::GetFormat() const {
if (impl != nullptr) {
return impl->format_;
}
return FORMAT_RESERVED;
}

void TensorDesc::SetFormat(Format format) {
if (impl != nullptr) {
impl->format_ = format;
}
}

Format TensorDesc::GetOriginFormat() const {
if (impl != nullptr) {
return impl->origin_format_;
}
return FORMAT_RESERVED;
}

void TensorDesc::SetOriginFormat(Format origin_format) {
if (impl != nullptr) {
impl->origin_format_ = origin_format;
}
}

DataType TensorDesc::GetDataType() const {
if (impl != nullptr) {
return impl->data_type_;
}
return DT_UNDEFINED;
}

void TensorDesc::SetDataType(DataType dt) {
if (impl != nullptr) {
impl->data_type_ = dt;
}
}

void TensorDesc::SetSize(int64_t size) {
if (impl != nullptr) {
impl->size_ = size;
}
}

int64_t TensorDesc::GetSize() const {
if (impl != nullptr) {
return impl->size_;
}
return 0;
}

void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) {
if (impl != nullptr) {
impl->real_dim_cnt_ = real_dim_cnt;
}
}

int64_t TensorDesc::GetRealDimCnt() const {
if (impl != nullptr) {
return impl->real_dim_cnt_;
}
return 0;
}

std::string TensorDesc::GetName() const {
if (impl != nullptr) {
return impl->name_;
}
return "";
}

void TensorDesc::SetName(const std::string &name) {
if (impl != nullptr) {
impl->name_ = name;
}
}

Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); }

Tensor::Tensor(const TensorDesc &tensor_desc) {
impl = ComGraphMakeShared<TensorImpl>(tensor_desc); // lint !e665
}

Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) {
uint64_t shape_size = tensor_desc.GetShape().GetShapeSize();
DataType data_type = tensor_desc.GetDataType();
uint32_t type_length;
bool ret = TypeUtils::GetDataTypeLength(data_type, type_length);
if (!ret) {
GELOGW("datatype %d is not found.", data_type);
}

auto data_size = data.size();
if (ret && (shape_size || (data_size != type_length))) {
if (type_length != 0 && UINT64_MAX / type_length < shape_size) {
GELOGW("mul overflow: %lu, %u", shape_size, type_length);
} else {
if (shape_size * type_length != data_size) {
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length,
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str());
}
}
}
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); // lint !e665
}

Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) {
uint64_t shape_size = tensor_desc.GetShape().GetShapeSize();
DataType data_type = tensor_desc.GetDataType();
uint32_t type_length;
bool ret = TypeUtils::GetDataTypeLength(data_type, type_length);
if (!ret) {
GELOGW("datatype %d is not found.", data_type);
}
if (ret && (shape_size || (size != type_length))) {
if (type_length != 0 && UINT64_MAX / type_length < shape_size) {
GELOGW("mul overflow: %lu, %u", shape_size, type_length);
} else {
if (shape_size * type_length != size) {
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length,
size, TypeUtils::DataTypeToSerialString(data_type).c_str());
}
}
}

impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); // lint !e665
}

Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) {
uint64_t shape_size = tensor_desc.GetShape().GetShapeSize();
DataType data_type = tensor_desc.GetDataType();
uint32_t type_length;
bool ret = TypeUtils::GetDataTypeLength(data_type, type_length);
if (!ret) {
GELOGW("datatype %d is not found.", data_type);
}

auto data_size = data.size();
if (ret && (shape_size || (data_size != type_length))) {
if (type_length != 0 && UINT64_MAX / type_length < shape_size) {
GELOGW("mul overflow: %lu, %u", shape_size, type_length);
} else {
if (shape_size * type_length != data_size) {
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length,
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str());
}
}
}
impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); // lint !e665
}

TensorDesc Tensor::GetTensorDesc() const {
if (impl != nullptr) {
return TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc());
}
return TensorDesc();
}

graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) {
if (impl != nullptr) {
impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc));
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

const uint8_t *Tensor::GetData() const {
if (impl != nullptr) {
return impl->ge_tensor.GetData().data();
}
return nullptr;
}

uint8_t *Tensor::GetData() {
if (impl != nullptr) {
return impl->ge_tensor.MutableData().data();
}
return nullptr;
}

size_t Tensor::GetSize() const {
if (impl != nullptr) {
return impl->ge_tensor.GetData().size();
}
return 0;
}

graphStatus Tensor::SetData(std::vector<uint8_t> &&data) {
if (impl != nullptr) {
(void)impl->ge_tensor.SetData(data);
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

graphStatus Tensor::SetData(const std::vector<uint8_t> &data) {
if (impl != nullptr) {
(void)impl->ge_tensor.SetData(data);
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

graphStatus Tensor::SetData(const uint8_t *data, size_t size) {
if (impl != nullptr) {
(void)impl->ge_tensor.SetData(data, size);
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

graphStatus Tensor::SetData(const std::string &data) {
if (impl != nullptr && (!data.empty())) {
/// Extra 8 bytes store pointer of string
/// Extra 1 byte store '\0'
size_t total_size = data.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL;
std::unique_ptr<char[]> buff(new (std::nothrow) char[total_size]());
if (buff == nullptr) {
GELOGE(GRAPH_FAILED, "allocate string raw data buff failed");
return GRAPH_FAILED;
}
uint64_t *p = reinterpret_cast<uint64_t *>(buff.get());
// Front 8 bytes store pointer of string
char *raw_data = buff.get() + EXTRA_STORE_POINTER_FOR_STRING;
p[0] = reinterpret_cast<uintptr_t>(raw_data);
int32_t memcpy_ret = memcpy_s(raw_data, total_size - EXTRA_STORE_POINTER_FOR_STRING, data.c_str(), data.size() + 1);
GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed");
(void)impl->ge_tensor.SetData(reinterpret_cast<const uint8_t *>(buff.get()), total_size);
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
graphStatus Tensor::SetData(const std::vector<std::string> &data) {
if (impl != nullptr) {
if (data.empty()) {
GELOGE(GRAPH_FAILED, "there is no data, please check the input variable");
return GRAPH_FAILED;
}
size_t total_size = 0;
for (auto str : data) {
/// Extra 8 bytes store pointer of each string
/// Extra 1 byte store '\0'
total_size += (str.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL);
}
std::unique_ptr<char[]> buff(new (std::nothrow) char[total_size]);
if (buff == nullptr) {
GELOGE(GRAPH_FAILED, "allocate string raw data buff failed");
return GRAPH_FAILED;
}
uint64_t *p = reinterpret_cast<uint64_t *>(buff.get());
// Front some bytes store pointer of each string
char *raw_data = buff.get() + data.size() * sizeof(uint64_t);
uint64_t ptr_size = data.size() * sizeof(uint64_t);
for (size_t i = 0; i < data.size(); ++i) {
p[i] = reinterpret_cast<uintptr_t>(raw_data);
if (total_size < ptr_size) {
GELOGE(GRAPH_FAILED, "Subtraction invalid, total_size: %zu, ptr_size: %lu", total_size, ptr_size);
return GRAPH_FAILED;
}
int32_t memcpy_ret = memcpy_s(raw_data, total_size - ptr_size, data[i].c_str(), data[i].size() + 1);
GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed");
raw_data += (data[i].size() + 1);
ptr_size += (data[i].size() + 1);
}

(void)impl->ge_tensor.SetData(reinterpret_cast<const uint8_t *>(buff.get()), total_size);
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}

graphStatus Tensor::IsValid() {
uint64_t shape_size = GetTensorDesc().GetShape().GetShapeSize();
DataType data_type = GetTensorDesc().GetDataType();
uint32_t type_length;
bool ret = TypeUtils::GetDataTypeLength(data_type, type_length);
if (!ret) {
GELOGW("datatype %d is not found.", data_type);
return GRAPH_SUCCESS;
}

size_t data_size = GetSize();
if (data_type != DT_STRING) {
if (shape_size || (data_size != type_length)) {
if (type_length != 0 && UINT64_MAX / type_length < shape_size) {
GELOGW("mul overflow: %lu, %u", shape_size, type_length);
} else {
if (shape_size * type_length != data_size) {
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length,
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str());
return GRAPH_FAILED;
}
}
}
}

return GRAPH_SUCCESS;
}

Tensor Tensor::Clone() const {
Tensor tensor;
if (impl != nullptr && tensor.impl != nullptr) {
tensor.impl->ge_tensor = impl->ge_tensor.Clone();
}
return tensor;
}

GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) {
GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(),
tensor_desc.GetDataType());
ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims()));
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat());
ge_tensor_desc.SetName(tensor_desc.GetName());
std::vector<std::pair<int64_t, int64_t>> shape_range;
auto status = tensor_desc.GetShapeRange(shape_range);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get shape range failed!");
return ge_tensor_desc;
}
status = ge_tensor_desc.SetShapeRange(shape_range);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Set shape range failed!");
return ge_tensor_desc;
}
auto size = tensor_desc.GetSize();
TensorUtils::SetSize(ge_tensor_desc, size);

auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt());
TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt);
return ge_tensor_desc;
}

TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) {
TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(),
ge_tensor_desc.GetDataType());
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims()));
tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat());
tensor_desc.SetName(ge_tensor_desc.GetName());
std::vector<std::pair<int64_t, int64_t>> shape_range;
auto status = ge_tensor_desc.GetShapeRange(shape_range);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get shape range failed!");
return tensor_desc;
}
status = tensor_desc.SetShapeRange(shape_range);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Set shape range failed!");
return tensor_desc;
}
int64_t size = 0;
(void)TensorUtils::GetSize(ge_tensor_desc, size);
tensor_desc.SetSize(size);

uint32_t real_dim_cnt = 0;
(void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt);
tensor_desc.SetRealDimCnt(real_dim_cnt);
return tensor_desc;
}

GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); // lint !e665
}
return ge_tensor;
}

Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) {
Tensor tensor;
if (ge_tensor != nullptr && tensor.impl != nullptr) {
tensor.impl->ge_tensor = ge_tensor->Clone();
}
return tensor;
}

ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665
}
return ge_tensor;
}

GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665
}
return ge_tensor;
}

const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) {
if (tensor.impl != nullptr) {
return tensor.impl->ge_tensor;
}
return GeTensor();
}

GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) {
if (tensor.impl != nullptr) {
return tensor.impl->ge_tensor;
}
return GeTensor();
}

const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) {
Tensor tensor;
if (tensor.impl != nullptr) {
tensor.impl->ge_tensor = ge_tensor;
}
return tensor;
}

Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) {
Tensor tensor;
if (tensor.impl != nullptr) {
tensor.impl->ge_tensor = ge_tensor;
}
return tensor;
}
} // namespace ge

+ 0
- 102
metadef/graph/utils/anchor_utils.cc View File

@@ -1,102 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "utils/anchor_utils.h"
#include <algorithm>
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Format AnchorUtils::GetFormat(const DataAnchorPtr &data_anchor) {
if (data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "The input data anchor is invalid.");
return FORMAT_RESERVED;
}
return data_anchor->format_;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetFormat(const DataAnchorPtr &data_anchor,
Format data_format) {
if ((data_anchor == nullptr) || (data_format == FORMAT_RESERVED)) {
GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid .");
return GRAPH_FAILED;
}
data_anchor->format_ = data_format;
return GRAPH_SUCCESS;
}

// Get anchor status
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) {
if (data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "The input data anchor is invalid.");
return ANCHOR_RESERVED;
}
return data_anchor->status_;
}

// Set anchor status
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor,
AnchorStatus anchor_status) {
if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) {
GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid .");
return GRAPH_FAILED;
}
data_anchor->status_ = anchor_status;
return GRAPH_SUCCESS;
}

bool AnchorUtils::HasControlEdge(const AnchorPtr &anchor) {
auto control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(anchor);
if (control_anchor != nullptr) {
return (control_anchor->GetPeerAnchors().size() != 0);
}

auto data_anchor = Anchor::DynamicAnchorCast<DataAnchor>(anchor);
if (data_anchor) {
for (const auto &peer : data_anchor->GetPeerAnchors()) {
auto peer_cast = Anchor::DynamicAnchorCast<ControlAnchor>(peer);
if (peer_cast) {
return true;
}
}
return false;
}
GELOGE(GRAPH_FAILED, "the anchor is neither control anchor nor data anchor");
return false;
}

bool AnchorUtils::IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst) {
GE_CHK_BOOL_EXEC(src != nullptr, return false, "src is null.");
GE_CHK_BOOL_RET_STATUS_NOLOG(src->IsLinkedWith(dst), false);
auto src_control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(src);
auto dst_control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(dst);
return (src_control_anchor || dst_control_anchor);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int AnchorUtils::GetIdx(const AnchorPtr &anchor) {
// Check if it can add edge between DataAnchor
auto data_anchor = Anchor::DynamicAnchorCast<DataAnchor>(anchor);
if (data_anchor != nullptr) {
return data_anchor->GetIdx();
}
// Check if it can add edge between ControlAnchor
auto control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(anchor);
if (control_anchor != nullptr) {
return control_anchor->GetIdx();
}
return -1;
}
} // namespace ge

+ 0
- 1178
metadef/graph/utils/ge_ir_utils.cc
File diff suppressed because it is too large
View File


+ 0
- 206
metadef/graph/utils/ge_ir_utils.h View File

@@ -1,206 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_

#include <google/protobuf/map.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/stubs/port.h>

#include <graph/anchor.h>
#include <graph/debug/ge_log.h>
#include <graph/debug/ge_util.h>
#include <graph/detail/attributes_holder.h>
#include <graph/ge_tensor.h>
#include <graph/graph.h>
#include <graph/model.h>
#include <graph/node.h>
#include <graph/utils/graph_utils.h>
#include <graph/utils/type_utils.h>

#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "proto/ge_ir.pb.h"
#include "proto/onnx.pb.h"

namespace ge {
const int kOffsetToString = 2;

///
/// @ingroup ge_ir_utils
/// @brief RepeatedField->String
/// @param [in] const rpd_field RepeatedField
/// @return String
///
template <typename T>
const std::string ToString(const google::protobuf::RepeatedField<T> &rpd_field) {
std::stringstream ss;
ss << "[";
for (const T &x : rpd_field) {
ss << x;
ss << ", ";
}
std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString);
str_ret += "]";
return str_ret;
}

///
/// @ingroup ge_ir_utils
/// @brief RepeatedPtrField->String
/// @param [in] const rpd_field RepeatedPtrField
/// @return String
///
template <typename T>
const std::string ToString(const google::protobuf::RepeatedPtrField<T> &rpd_ptr_field) {
std::stringstream ss;
ss << "[";
for (const T &x : rpd_ptr_field) {
ss << x;
ss << ", ";
}
std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString);
str_ret += "]";
return str_ret;
}

///
/// @ingroup ge_ir_utils
/// @brief check, if not equal, log with tag
/// @param [in] const left_value, right_value reference, log_info_tag
/// @return bool
///
template <typename T>
bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) {
if (l_value == r_value) {
return true;
} else {
GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str());
return false;
}
}

class OnnxUtils {
public:
enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END };

static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto);

static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model);

private:
// Part 1: from IR convert to ONNX Protobuf
static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, void *data);

static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data);

static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedField<bool> data);

static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedField<float> data);

static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data);

static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto);

static void AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
onnx::NodeProto *node_proto);

static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc);

static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map,
onnx::NodeProto *node_proto, const std::string &prefix = "",
const std::string &suffix = "");

static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto);

static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type);

static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto);

static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto);

static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto);

static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto);

static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type);

static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v);

static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto);

/// Part 2: from ONNX Protobuf convert to IR
/// Describes node's link relationships
struct NodeLinkInfo {
std::string src_node_name;
int32_t src_out_index;
NodePtr dst_node;
int32_t dst_in_index;
std::string dst_node_name;
};

// Parse node name and index
static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index);

static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type);

static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings);

static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints);

static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value);

static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value);

static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_output_desc, int32_t index,
OpDescPtr &op_desc);

static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_desc, int32_t index,
OpDescPtr &op_desc);

static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_output_desc, int32_t index,
OpDescPtr &op_desc);

static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def);

static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc);

static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr);

static bool DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
const std::map<std::string, NodePtr> &node_map);

static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node);

static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph);
};
} // namespace ge

#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_

+ 0
- 2767
metadef/graph/utils/graph_utils.cc
File diff suppressed because it is too large
View File


+ 0
- 32
metadef/graph/utils/mem_utils.h View File

@@ -1,32 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_
#define COMMON_GRAPH_UTILS_MEM_UTILS_H_

#include <memory>
#include <utility>

namespace ge {
template <typename _Tp, typename... _Args>
static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) {
typedef typename std::remove_const<_Tp>::type _Tp_nc;
std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...));
return ret;
}
}

#endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_

+ 0
- 956
metadef/graph/utils/node_utils.cc View File

@@ -1,956 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/types.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

namespace ge {
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};

const std::set<std::string> kConstOpTypes = {"Const", "Constant"};

const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
const std::set<std::string> kCaseOpTypes = {"Case"};
const std::set<std::string> kForOpTypes = {"For"};

bool OpShapeIsUnknown(const OpDescPtr &desc) {
for (const auto &ptr : desc->GetAllInputsDescPtr()) {
auto ge_shape = ptr->GetShape();
for (const auto &dim : ge_shape.GetDims()) {
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
return true;
}
}
}
for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
auto ge_shape = ptr->GetShape();
for (const auto &dim : ge_shape.GetDims()) {
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
return true;
}
}
}
return false;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
const uint32_t &event_id) {
GE_CHECK_NOTNULL(node);
map_send_info_[node].push_back(event_id);
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
const uint32_t &event_id) {
GE_CHECK_NOTNULL(node);
map_recv_info_[node].push_back(event_id);
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
GE_CHECK_NOTNULL(node);
auto find = map_send_info_.find(node);
if (find == map_send_info_.end()) {
return GRAPH_FAILED;
} else {
vec_send = find->second;
return GRAPH_SUCCESS;
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
GE_CHECK_NOTNULL(node);
auto find = map_recv_info_.find(node);
if (find == map_recv_info_.end()) {
return GRAPH_FAILED;
} else {
vec_recv = find->second;
return GRAPH_SUCCESS;
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
map_send_info_.clear();
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
map_recv_info_.clear();
return GRAPH_SUCCESS;
}

graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
GE_CHECK_NOTNULL(src);
NodePtr cur_ptr;
if (depth < 1) {
return GRAPH_FAILED;
}
for (int i = 0; i < depth; i++) {
if (src->GetOutDataNodes().size() != 1) {
return GRAPH_FAILED;
}
cur_ptr = src->GetOutDataNodes().at(0);
GE_CHECK_NOTNULL(cur_ptr);
}
dst = cur_ptr;
return GRAPH_SUCCESS;
}

graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
InControlAnchorPtr &in_control) {
GE_CHECK_NOTNULL(node_ptr);
for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
for (const auto &p_in : p->GetPeerInControlAnchors()) {
GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
out_data = p;
in_control = p_in;
return GRAPH_SUCCESS;
}
}
return GRAPH_FAILED;
}

graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
"node or in_data_anchor is nullptr");

bool find_flag = false;
uint32_t index = 0;
vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
for (const auto &tmp : node_ptr->in_data_anchors_) {
if (tmp == in_data_anchor) {
find_flag = true;
auto iter = node_ptr->in_data_anchors_.begin() + index;
if (iter != node_ptr->in_data_anchors_.end()) {
it = node_ptr->in_data_anchors_.erase(iter);
}
break;
}
index++;
}
for (; it != node_ptr->in_data_anchors_.end(); ++it) {
(*it)->SetIdx(index);
index++;
}

if (!find_flag) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
return GRAPH_SUCCESS;
}

graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
node.anchor_status_updated_ = true;
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
return IsAnchorStatusSet(*node_ptr);
}

bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }

graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
if ((origin_node == nullptr) || (new_node == nullptr)) {
return GRAPH_FAILED;
}
auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
return GRAPH_FAILED;
}

for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
"unlink peer_anchor failed");
GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
"linkto peer_anchor failed");
}

for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
"unlink peer_anchor failed");
GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
"linkto peer_anchor failed");
}
}

auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
GE_CHECK_NOTNULL(origin_out_control_anchor);
auto new_out_control_anchor = new_node->GetOutControlAnchor();
GE_CHECK_NOTNULL(new_out_control_anchor);
for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
"linkto peer_anchor failed");
}
for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
"linkto peer_anchor failed");
}
origin_out_control_anchor->UnlinkAll();

return GRAPH_SUCCESS;
}

bool NodeUtils::IsConst(const Node &node) {
auto src_node_type = node.GetType();
bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
return is_const;
}

void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "node is null");
return;
}
UpdateIsInputConst(*node_ptr);
}

///
/// update is_input_const
/// @param node
/// @return void
///
void NodeUtils::UpdateIsInputConst(Node &node) {
std::vector<bool> is_input_const;
size_t anchor_num = node.GetAllInDataAnchors().size();
for (size_t i = 0; i < anchor_num; i++) {
auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
if (in_anchor == nullptr) {
is_input_const.push_back(false);
continue;
}
auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
is_input_const.push_back(false);
continue;
}
auto src_node = peer_out_anchor->GetOwnerNode();
if (src_node == nullptr) {
is_input_const.push_back(false);
continue;
}
if (IsConst(*(src_node))) {
is_input_const.push_back(true);
} else {
is_input_const.push_back(false);
}
}
if (node.GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
return;
}
node.GetOpDesc()->SetIsInputConst(is_input_const);
}

void NodeUtils::UnlinkAll(const Node &node) {
for (const auto &anchor : node.GetAllOutAnchors()) {
anchor->UnlinkAll();
}
for (const auto &anchor : node.GetAllInAnchors()) {
anchor->UnlinkAll();
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
return GRAPH_FAILED;
}
auto op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
return GRAPH_FAILED;
}
bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
if (is_unknown_graph) {
return GRAPH_SUCCESS;
}
for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
auto out_dims = output_tensor->GetShape().GetDims();
auto out_dtype = output_tensor->GetDataType();
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
output_tensor->SetOriginShape(output_tensor->GetShape());
output_tensor->SetOriginDataType(output_tensor->GetDataType());

GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());

for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
continue;
}
auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
if (peer_input_desc == nullptr) {
GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
continue;
}
// check shape and dtype continuity. do not stop process
auto peer_input_dims = peer_input_desc->GetShape().GetDims();
auto peer_input_dtype = peer_input_desc->GetDataType();
if (out_dtype != peer_input_dtype) {
GELOGW(
"current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th "
"input_dtype is [%s].The two dtype should be same! Please check graph and fix it",
node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(),
peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(),
TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str());
} else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) {
string out_shape_str, peer_in_shape_str;
out_shape_str += "[";
for (int64_t dim : out_dims) {
out_shape_str += std::to_string(dim) + " ";
}
out_shape_str += "]";
peer_in_shape_str += "[";
for (int64_t dim : peer_input_dims) {
peer_in_shape_str += std::to_string(dim) + " ";
}
peer_in_shape_str += "]";

GELOGW(
"current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
"input_shape is [%s].The two shape should be same! Please check graph and fix it",
node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(),
peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str());
}
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
output_tensor->GetDataType(), output_tensor->GetOriginDataType());
peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
peer_input_desc->SetShape(output_tensor->GetShape());
peer_input_desc->SetDataType(output_tensor->GetDataType());
peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)output_tensor->GetShapeRange(shape_range);
peer_input_desc->SetShapeRange(shape_range);
ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
}
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node,
uint32_t num) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Input node is null");
return GRAPH_FAILED;
}

GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
const auto &op_desc = node->GetOpDesc();
for (size_t i = op_desc->GetInputsSize(); i < num; ++i) {
if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add input desc failed");
return GRAPH_FAILED;
}

auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
if (anchor == nullptr) {
GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
return GRAPH_FAILED;
}
node->in_data_anchors_.push_back(anchor);
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node,
uint32_t num) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Input node is null");
return GRAPH_FAILED;
}

const auto &op_desc = node->GetOpDesc();
while (op_desc->GetInputsSize() > num) {
if (!OpDescUtils::ClearInputDesc(op_desc, num)) {
return GRAPH_FAILED;
}
}

auto input_names = op_desc->GetAllInputName();
(void)op_desc->UpdateInputName(input_names);
auto is_input_const = op_desc->GetIsInputConst();
is_input_const.resize(num);
op_desc->SetIsInputConst(is_input_const);

while (node->in_data_anchors_.size() > num) {
node->in_data_anchors_.pop_back();
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node,
uint32_t num) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Input node is null");
return GRAPH_FAILED;
}

GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
const OpDescPtr &op_desc = node->GetOpDesc();
for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) {
if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add output desc failed");
return GRAPH_FAILED;
}

auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
if (anchor == nullptr) {
GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
return GRAPH_FAILED;
}
node->out_data_anchors_.push_back(anchor);
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node,
uint32_t num) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Input node is null");
return GRAPH_FAILED;
}

const auto &op_desc = node->GetOpDesc();
auto output_names = op_desc->GetAllOutputName();
while (op_desc->GetOutputsSize() > num) {
if (!OpDescUtils::ClearOutputDesc(op_desc, num)) {
return GRAPH_FAILED;
}
}
(void)op_desc->UpdateOutputName(output_names);

while (node->out_data_anchors_.size() > num) {
node->out_data_anchors_.pop_back();
}

return GRAPH_SUCCESS;
}

bool NodeUtils::IsInNodesEmpty(const Node &node) {
for (const auto &in_anchor : node.in_data_anchors_) {
if (in_anchor != nullptr) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor != nullptr) {
if (out_anchor->GetOwnerNode() != nullptr) {
return false;
}
}
}
}

if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
for (const auto &out_control_anchor : peer_out_control_anchors) {
if (out_control_anchor != nullptr) {
if (out_control_anchor->GetOwnerNode() != nullptr) {
return false;
}
}
}
}

return true;
}
GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
auto desc = node.GetOpDesc();
if (desc == nullptr) {
return GeTensorDesc();
}
return desc->GetOutputDesc(index);
}
GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
auto desc = node.GetOpDesc();
if (desc == nullptr) {
return GeTensorDesc();
}
return desc->GetInputDesc(index);
}
graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
auto desc = node.GetOpDesc();
if (desc == nullptr) {
return GRAPH_PARAM_INVALID;
}
auto output_desc = desc->MutableOutputDesc(index);
if (output_desc == nullptr) {
return GRAPH_PARAM_INVALID;
}
output_desc->SetShape(shape);
return GRAPH_SUCCESS;
}
graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
auto desc = node.GetOpDesc();
if (desc == nullptr) {
return GRAPH_PARAM_INVALID;
}
auto input_desc = desc->MutableInputDesc(index);
if (input_desc == nullptr) {
return GRAPH_PARAM_INVALID;
}
input_desc->SetShape(shape);
return GRAPH_SUCCESS;
}

graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
auto desc = node.GetOpDesc();
GE_CHECK_NOTNULL(desc);
// check self
is_unknow = OpShapeIsUnknown(desc);
if (is_unknow) {
return GRAPH_SUCCESS;
}
auto sub_graph_names = desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return GRAPH_SUCCESS;
} else {
auto owner_graph = node.GetOwnerComputeGraph();
GE_CHECK_NOTNULL(owner_graph);
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (root_graph == nullptr) {
GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
return GRAPH_PARAM_INVALID;
}
for (auto &sub_graph_name : sub_graph_names) {
auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
GE_CHECK_NOTNULL(sub_graph);
for (const auto &node_ptr : sub_graph->GetDirectNode()) {
auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
if (status != GRAPH_SUCCESS) {
GE_LOGE("get node unknown shape status failed!");
return status;
}
if (is_unknow) {
return GRAPH_SUCCESS;
}
}
}
}
return GRAPH_SUCCESS;
}

std::string NodeUtils::GetNodeType(const Node &node) {
if (node.GetType() != FRAMEWORKOP) {
return node.GetType();
}

std::string type;
(void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
return type;
}

std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); }

graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) {
return GRAPH_SUCCESS;
}

graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) {
return GRAPH_SUCCESS;
}

ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
return nullptr;
}
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (root_graph == nullptr) {
return nullptr;
}
return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
}

graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
if (subgraph == nullptr) {
GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
return GRAPH_PARAM_INVALID;
}
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
return GRAPH_PARAM_INVALID;
}
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (root_graph == nullptr) {
GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
return GRAPH_PARAM_INVALID;
}
auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
if (ret != GRAPH_SUCCESS) {
GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
return ret;
}
subgraph->SetParentNode(node.shared_from_this());
subgraph->SetParentGraph(node.GetOwnerComputeGraph());
return root_graph->AddSubgraph(subgraph);
}

///
/// Check if node is input of subgraph
/// @param [in] node
/// @return bool
///
bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
return false;
}

auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
if (parent_op_desc == nullptr) {
return false;
}

// dynamic shape unknown graph false
// dynamic shape known graph with functional subgraph maybe true
if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
return false;
} else {
if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
return false;
}
}
}

return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
}

///
/// Check if node is output of subgraph
/// @param [in] node
/// @return bool
///
bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
return false;
}

auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
if (parent_op_desc == nullptr) {
return false;
}

if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
return false;
} else {
if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
return false;
}
}
}

for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
return true;
}
}

return false;
}

///
/// @brief Get subgraph original input node.
/// @param [in] node
/// @return Node
///
NodePtr NodeUtils::GetParentInput(const Node &node) {
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
return nullptr;
}

// Subgraph Data Node, check for constant input.
const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
GE_CHECK_NOTNULL_EXEC(graph, return nullptr);

const NodePtr &parent_node = graph->GetParentNode();
GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);

const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);

const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);

return peer_out_anchor->GetOwnerNode();
}

NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); }

///
/// @brief Get is dynamic shape graph from node.
/// @param [in] node
/// @return bool
///
bool NodeUtils::IsDynamicShape(const Node &node) {
const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (graph == nullptr) {
return false;
}

bool is_dynamic_shape = false;
(void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
return is_dynamic_shape;
}

bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); }

///
/// @brief Check is varying_input for while node
/// @param [in] node: Data node for subgraph
/// @return bool
///
bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
if (node == nullptr) {
return false;
}
if (node->GetType() != DATA) {
return false; // not input_node for subgraph
}

const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
if (parent_node == nullptr) {
return false; // root graph
}

if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
return false; // not input_node for while subgraph
}

uint32_t index_i = 0;
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
return false;
}
bool varying_flag = true;
for (const auto &item : node->GetOutDataNodesAndAnchors()) {
if (item.first->GetType() != NETOUTPUT) {
continue;
}
OpDescPtr op_desc = item.first->GetOpDesc();
uint32_t index_o = 0;
if ((op_desc == nullptr) ||
!AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
continue; // input for while-cond subgraph
}
if (index_i != index_o) {
continue; // varying input for while-body subgraph
}
varying_flag = false;
break;
}
return varying_flag;
}

///
/// @brief Get subgraph input is constant.
/// @param [in] node
/// @param [out] string
/// @return bool
///
bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) {
if (node == nullptr) {
return false;
}

if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
type = node->GetType();
return true;
}

if (node->GetType() != DATA) {
return false; // not subgraph input node
}

const auto &parent = GetParentInput(node);
return GetConstOpType(parent, type);
}

///
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
/// @param [in] node
/// @return return GRAPH_SUCCESS if remove successfully, other for failed.
///
Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
return GRAPH_SUCCESS;
} else {
auto owner_graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(owner_graph);
auto root_graph = GraphUtils::FindRootGraph(owner_graph);
GE_CHECK_NOTNULL(root_graph);

std::unordered_set<std::string> subgraph_to_remove;
for (auto &subgraph_name : subgraph_names) {
std::deque<std::string> queue;
queue.push_back(subgraph_name);
subgraph_to_remove.insert(subgraph_name);
op_desc->RemoveSubgraphInstanceName(subgraph_name);
while (!queue.empty()) {
auto graph_name = queue.front();
queue.pop_front();

auto subgraph = root_graph->GetSubgraph(graph_name);
GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetDirectNode()) {
auto sub_op_desc = sub_node->GetOpDesc();
GE_CHECK_NOTNULL(sub_op_desc);
auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
// Subgraph and all nodes in it will be removed later,
// no need to remove 'SubgraphInstanceName' in op desc here.
for (auto &name : sub_names) {
if (subgraph_to_remove.insert(name).second) {
queue.push_back(name);
}
}
}
}
}
// Remove subgraph from root_graph
for (const auto &name : subgraph_to_remove) {
GELOGI("Remove subgraph:%s.", name.c_str());
root_graph->RemoveSubgraph(name);
}
}

return GRAPH_SUCCESS;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
vector<NodePtr> in_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
return in_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
int parent_index = -1;
if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
(void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
if (parent_index == index) {
in_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
}
return in_data_node_vec;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
vector<NodePtr> out_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
return out_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
out_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
return out_data_node_vec;
}

NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) {
if (node.GetInDataAnchor(index) == nullptr) {
return nullptr;
}
if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
return nullptr;
}
return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
}

vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) {
vector<pair<InDataAnchorPtr, NodePtr>> out_data_nodes;
auto out_data_anchor = node.GetOutDataAnchor(index);
if (out_data_anchor == nullptr) {
return out_data_nodes;
}

for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
if (peer_in_anchor == nullptr) {
continue;
}
if (peer_in_anchor->GetOwnerNode() == nullptr) {
continue;
}
out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode()));
}
return out_data_nodes;
}

ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); }
} // namespace ge

+ 0
- 778
metadef/graph/utils/op_desc_utils.cc View File

@@ -1,778 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "utils/op_desc_utils.h"
#include <algorithm>
#include "debug/ge_attr_define.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "utils/graph_utils.h"
#include "utils/node_utils.h"

using std::vector;

/*lint -e512 -e737 -e752*/
namespace ge {
const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1;

bool OpDescUtils::ClearInputDesc(const NodePtr &node) {
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
vector<int> index_list;
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
if (in_anchor->GetPeerOutAnchor() == nullptr) {
index_list.push_back(in_anchor->GetIdx());
}
}
std::sort(index_list.begin(), index_list.end());
// Node's in anchor index need shrink
for (size_t i = 0; i < index_list.size(); ++i) {
auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i];
if (iter < node->GetOpDesc()->inputs_desc_.end()) {
(void)node->GetOpDesc()->inputs_desc_.erase(iter);
} else {
GELOGW("inputs_desc_ iterator out of range.");
}
}

return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc,
const uint32_t index) {
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index);

auto iter = op_desc->inputs_desc_.begin() + index;
if (iter < op_desc->inputs_desc_.end()) {
(void)op_desc->inputs_desc_.erase(iter);
} else {
GELOGW("inputs_desc_ iterator out of range.");
}
return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) {
GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr");
return op_desc->HasAttr(OP_DESC_QUANT_PARAMS);
}

bool OpDescUtils::ClearOutputDesc(const NodePtr &node) {
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr");
GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr");
vector<int> index_list;
for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
if (out_anchor->GetPeerInDataAnchors().empty()) {
index_list.push_back(out_anchor->GetIdx());
}
}
std::sort(index_list.begin(), index_list.end());
// Node's out anchor index need shrink
for (size_t i = 0; i < index_list.size(); ++i) {
auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i];
if (iter < node->GetOpDesc()->outputs_desc_.end()) {
(void)node->GetOpDesc()->outputs_desc_.erase(iter);
} else {
GELOGW("outputs_desc_ iterator out of range.");
}
}

return true;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc,
uint32_t index) {
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr");
GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index);

auto iter = op_desc->outputs_desc_.begin() + index;
if (iter < op_desc->outputs_desc_.end()) {
(void)op_desc->outputs_desc_.erase(iter);
} else {
GELOGW("outputs_desc_ iterator out of range.");
}
return true;
}

bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) {
GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
GeAttrValue attr_value;
GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
"GetQuantizeFactorParams failed");
return attr_value.GetValue<QuantizeFactorParams>(quant);
}

graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) {
GeAttrValue attr_value;
GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED,
"GetQuantizeFactorParams failed");
return attr_value.GetValue<QuantizeFactorParams>(quant);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
}

graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
}

GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
GeTensorPtr weight = nullptr;
if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) {
GELOGW("MutableTensor error");
}

return weight;
}

GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) {
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc is null");
return nullptr;
}
return MutableWeights(*op_desc);
}

graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) {
if (weight == nullptr) {
GELOGE(GRAPH_FAILED, "weight is null");
return GRAPH_FAILED;
}
return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED;
}

graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) {
GE_CHECK_NOTNULL(op_desc);
GE_CHECK_NOTNULL(weight);
return SetWeights(*op_desc, weight);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) {
auto weights = MutableWeights(node);
vector<ConstGeTensorPtr> ret(weights.size());
std::copy(weights.begin(), weights.end(), ret.begin());
return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(
const ge::ConstNodePtr &node) {
if (node == nullptr) {
return vector<ge::ConstGeTensorPtr>();
}
return GetWeights(*node);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode(
const ge::Node &node) {
vector<ge::NodePtr> ret;
auto in_anchors = node.GetAllInDataAnchors();
for (const auto &in_anchor : in_anchors) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
// normally out_anchor could be null, this is ok
GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str());
continue;
}
auto in_node = out_anchor->GetOwnerNode();
while (true) {
if (in_node == nullptr) {
break;
}
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
break;
} else if (in_node->GetType() == DATA) {
if (NodeUtils::IsWhileVaryingInput(in_node)) {
break;
}
in_node = NodeUtils::GetParentInput(in_node);
} else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
bool is_constant = false;
(void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
if (!is_constant) {
break;
}
// Enter node has and only has one input
if (in_node->GetInDataNodes().size() != 1) {
GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
in_node->GetInDataNodes().size());
break;
}
in_node = in_node->GetInDataNodes().at(0);
} else {
break;
}
}
}
return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
const vector<ge::NodePtr> &input_nodes) {
vector<ConstGeTensorPtr> ret;

for (const auto &input_node : input_nodes) {
auto temp_weight = MutableWeights(input_node->GetOpDesc());
if (temp_weight == nullptr) {
GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
return vector<ConstGeTensorPtr>();
}
ret.push_back(temp_weight);
}

return ret;
}
size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
if (NodeUtils::IsAnchorStatusSet(node)) {
size_t input_num = 0;
for (const auto &anchor : node.GetAllInDataAnchors()) {
if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
input_num++;
continue;
}
}
return input_num; // lint !e712
} else {
GE_IF_BOOL_EXEC(
node.GetInDataNodes().size() < GetConstInputs(node).size(),
GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size());
return 0);
return node.GetInDataNodes().size() - GetConstInputs(node).size();
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node is nullptr");
return 0;
}
return GetNonConstInputsSize(*node);
}

GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) {
GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!");
size_t i = 0;
if (NodeUtils::IsAnchorStatusSet(node)) {
for (const auto &anchor : node.GetAllInDataAnchors()) {
if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
if (index_non_const == i) {
return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx()));
}
++i;
}
}
} else {
for (const auto &anchor : node.GetAllInDataAnchors()) {
auto peer_anchor = anchor->GetPeerOutAnchor();
if (peer_anchor == nullptr) {
continue;
}
auto owner_node = peer_anchor->GetOwnerNode();
if (owner_node == nullptr) {
continue;
}
if (owner_node->GetType() == CONSTANT) {
continue;
}
if (index_non_const == i) {
return node.GetOpDesc()->GetInputDesc(anchor->GetIdx());
}
++i;
}
}
return GeTensorDesc();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc
OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) {
CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc());
return GetNonConstInputTensorDesc(*node, index_non_const);
}

bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) {
bool ret = false;
size_t i = 0;
if (NodeUtils::IsAnchorStatusSet(node)) {
for (const auto &anchor : node.GetAllInDataAnchors()) {
if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) {
if (index_non_const == i) {
index = static_cast<size_t>(anchor->GetIdx());
ret = true;
}
++i;
}
}
} else {
for (const auto &anchor : node.GetAllInDataAnchors()) {
auto peer_anchor = anchor->GetPeerOutAnchor();
if (peer_anchor == nullptr) {
continue;
}
auto owner_node = peer_anchor->GetOwnerNode();
if (owner_node == nullptr) {
continue;
}
if (owner_node->GetType() == CONSTANT) {
continue;
}
if (index_non_const == i) {
index = static_cast<size_t>(anchor->GetIdx());
ret = true;
}
++i;
}
}
return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node,
size_t index_non_const,
size_t &index) {
CHECK_FALSE_EXEC(node != nullptr, return false);
return GetNonConstInputIndex(*node, index_non_const, index);
}

bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
bool ret = false;
if (index < node.GetAllInDataAnchors().size()) {
if (NodeUtils::IsAnchorStatusSet(node)) {
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712
} else {
for (const auto &anchor : node.GetAllInDataAnchors()) {
if (anchor->GetIdx() != static_cast<int>(index)) {
continue;
}
auto peer_anchor = anchor->GetPeerOutAnchor();
if (peer_anchor == nullptr) {
break;
}
auto owner_node = peer_anchor->GetOwnerNode();
if (owner_node == nullptr) {
break;
}
ret = (owner_node->GetType() != CONSTANT);
}
}
}

return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node,
size_t index) {
CHECK_FALSE_EXEC(node != nullptr, return false);
return IsNonConstInput(*node, index);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(
const ge::ConstNodePtr &node) {
if (node == nullptr) {
return vector<ge::NodePtr>();
}
return GetConstInputs(*node);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc(
const ge::ConstNodePtr &node) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
return vector<ge::GeTensorDesc>();
}
vector<ge::GeTensorDesc> ret;
if (NodeUtils::IsAnchorStatusSet(*node)) {
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
}
}
} else {
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
continue;
}
if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
}
}
}
return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) {
vector<ge::NodePtr> ret;
auto in_anchors = node.GetAllInDataAnchors();
for (const auto &in_anchor : in_anchors) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) continue;

auto in_node = out_anchor->GetOwnerNode();
if (in_node->GetType() == CONSTANT) {
ret.push_back(in_node);
} else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) {
// const --> switch --> matmul
auto switch_input = GetConstInputs(*in_node);
if (switch_input.size() > 0) {
ret.insert(ret.end(), switch_input.begin(), switch_input.end());
}
} else if (in_node->GetType() == DATA) {
auto parent = NodeUtils::GetParentInput(in_node);
if ((parent != nullptr) && (parent->GetType() == CONSTANT)) {
ret.push_back(parent);
}
}
}
return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
vector<GeTensorPtr> ret;
auto op_desc = node.GetOpDesc();
GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
// Place holder operator, try to get the weight from parent node
// when parent node is const operator
if (node.GetType() == PLACEHOLDER) {
std::string parent_op;
(void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
// This if judgment is necessary because the current subgraph optimization is multithreaded
// and the parent node of the PLD operation should be a stable type, such as const
if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
NodePtr parent_node = nullptr;
parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
if (parent_node != nullptr) {
op_desc = parent_node->GetOpDesc();
GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
}
}
}
// Const operator, take the weight directly
if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(op_desc);
if (weight == nullptr) {
GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
return ret;
}
ret.push_back(weight);
return ret;
}

if (node.GetType() == DATA) {
auto parent = NodeUtils::GetParentInput(node);
if ((parent != nullptr) && NodeUtils::IsConst(*parent)) {
auto weight = MutableWeights(parent->GetOpDesc());
if (weight == nullptr) {
GELOGI("const op has no weight, op name:%s", parent->GetName().c_str());
return ret;
}
ret.push_back(weight);
}
return ret;
}

// Other operators, get weights from connected constop
auto input_nodes = GetConstInputs(node);
for (const auto &input_node : input_nodes) {
auto temp_weight = MutableWeights(input_node->GetOpDesc());
if (temp_weight == nullptr) {
GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str());
return vector<GeTensorPtr>();
}
ret.push_back(temp_weight);
}

return ret;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node is nullptr");
return vector<ge::GeTensorPtr>();
}
return MutableWeights(*node);
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) {
GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!");
if (node.GetOpDesc()->GetType() == CONSTANT) {
if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) {
return SetWeights(node.GetOpDesc(), weights[0]);
}
GELOGI("const op weight size %zu should be 1", weights.size());
return GRAPH_PARAM_INVALID;
}

auto input_nodes = GetConstInputs(node);
if (weights.size() < input_nodes.size()) {
GELOGE(GRAPH_FAILED, "weights count can't be less than const input count");
return GRAPH_PARAM_INVALID;
}

ge::GeAttrValue::NAMED_ATTRS named_attrs;
(void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights);
vector<ge::GeTensorPtr> copy_weights;
(void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights);

for (size_t i = 0; i < input_nodes.size(); ++i) {
if (input_nodes[i]->GetOpDesc() != nullptr) {
SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]);
}
}

// If set more weights than constop, need to add constop
for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) {
// Use org weight before SetWeights Overwrite
auto const_opdesc = CreateConstOp(copy_weights[i]);
GE_CHECK_NOTNULL(const_opdesc);

auto owner_graph = node.GetOwnerComputeGraph();
if (owner_graph == nullptr) {
GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str());
return GRAPH_PARAM_INVALID;
}
auto const_node = owner_graph->AddNodeFront(const_opdesc);
GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed!");
std::vector<ge::NodePtr> original_nodes;
ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
}
return GRAPH_SUCCESS;
}

OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) {
GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!");
shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>();
if (const_opdesc == nullptr) {
GELOGE(GRAPH_FAILED, "failed to make_shared ");
return nullptr;
}

CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr);

const_opdesc->SetType(CONSTANT);

thread_local int64_t const_count = 0;
const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count));
GELOGI("add const op: %s", const_opdesc->GetName().c_str());
++const_count;

(void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc());

GELOGI("after add const op: %s", const_opdesc->GetName().c_str());

return const_opdesc;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) {
GE_CHECK_NOTNULL(in_anchor);
GE_CHECK_NOTNULL(tensor_ptr);
auto const_opdesc = CreateConstOp(tensor_ptr);
GE_CHECK_NOTNULL(const_opdesc);
auto in_node = in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(in_node);
auto owner_graph = in_node->GetOwnerComputeGraph();
if (owner_graph == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str());
return GRAPH_PARAM_INVALID;
}
auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc);
GE_CHECK_NOTNULL(const_node);
if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed.");
return GRAPH_PARAM_INVALID;
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) {
GE_CHECK_NOTNULL(node);
return SetWeights(*node, weights);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) {
GE_CHECK_NOTNULL(node);
auto const_ops = GetConstInputs(node);
auto graph = node->GetOwnerComputeGraph();
if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "Graph is nullptr");
return GRAPH_PARAM_INVALID;
}
for (const auto &const_op : const_ops) {
GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed",
const_op->GetName().c_str(), const_op->GetType().c_str());
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op),
"Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(),
const_op->GetType().c_str());
}
return GRAPH_SUCCESS;
}

///
/// @brief Add input
/// @param [in] name
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) {
inputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
return *this;
}

///
/// @brief Add input
/// @param [in] name
/// @param [in] tensor
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name,
const GeTensorDesc &tensor) {
inputs_.emplace_back(std::make_pair(name, tensor));
return *this;
}

///
/// @brief Add dynamic input
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name,
uint32_t num) {
for (uint32_t i = 0; i < num; i++) {
inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
}
return *this;
}

///
/// @brief Add dynamic input
/// @param [in] name
/// @param [in] num
/// @param [in] tensor
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(
const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
for (uint32_t i = 0; i < num; i++) {
inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
}
return *this;
}

///
/// @brief Add output
/// @param [in] name
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) {
outputs_.emplace_back(std::make_pair(name, GeTensorDesc()));
return *this;
}

///
/// @brief Add output
/// @param [in] name
/// @param [in] tensor
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name,
const GeTensorDesc &tensor) {
outputs_.emplace_back(std::make_pair(name, tensor));
return *this;
}

///
/// @brief Add dynamic output
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name,
uint32_t num) {
for (uint32_t i = 0; i < num; i++) {
outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc()));
}
return *this;
}

///
/// @brief Add dynamic output
/// @param [in] name
/// @param [in] num
/// @param [in] tensor
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(
const std::string &name, uint32_t num, const GeTensorDesc &tensor) {
for (uint32_t i = 0; i < num; i++) {
outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor));
}
return *this;
}

///
/// @brief Build op_desc
/// @return OpDescPtr
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
return nullptr;
}

for (auto &input : inputs_) {
if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add input_desc failed.");
return nullptr;
}
}

for (auto &output : outputs_) {
if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add output_desc failed.");
return nullptr;
}
}

return op_desc;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName(
const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) {
const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes();
auto iter = subgraph_names_to_index.find(subgraph_name);
if (iter == subgraph_names_to_index.end()) {
GELOGE(GRAPH_PARAM_INVALID,
"Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists",
subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(),
subgraph_name.c_str());
return GRAPH_PARAM_INVALID;
}

return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
}
} // namespace ge
/*lint +e512 +e737 +e752*/

+ 0
- 68
metadef/graph/utils/string_utils.h View File

@@ -1,68 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_UTILS_STRING_UTILS_H_
#define COMMON_GRAPH_UTILS_STRING_UTILS_H_

#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <vector>
#include "securec.h"

namespace ge {
class StringUtils {
public:
static std::string &Ltrim(std::string &s) {
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
return s;
}

static std::string &Rtrim(std::string &s) {
(void)s.erase(std::find_if(s.rbegin(), s.rend(), [](int c) { return !std::isspace(c); }).base(), s.end());
return s;
}

/// @ingroup domi_common
/// @brief trim space
static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); }

// split string
static std::vector<std::string> Split(const std::string &str, char delim) {
std::vector<std::string> elems;

if (str.empty()) {
elems.emplace_back("");
return elems;
}

std::stringstream ss(str);
std::string item;

while (getline(ss, item, delim)) {
elems.push_back(item);
}
auto str_size = str.size();
if (str_size > 0 && str[str_size - 1] == delim) {
elems.emplace_back("");
}

return elems;
}
};
} // namespace ge
#endif // COMMON_GRAPH_UTILS_STRING_UTILS_H_

+ 0
- 401
metadef/graph/utils/tensor_utils.cc View File

@@ -1,401 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/utils/tensor_utils.h"
#include <cmath>

#include "debug/ge_log.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_tensor.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace {
// When nc1hwc0 dim size = 5, calc element count directly.
const uint32_t kNc1hwc0CalcByDimsSize = 5;

// Unknown shape element num
const int64_t kElementCntUnknownShape = -1;

// Unknown shape mem size
const int64_t kMemSizeUnknownShape = -1;

// Nchw and nhwc dim size must be 4
const uint32_t kDimSize4d = 4;

// C1HWNCoC0 dim size must be 6
const uint32_t kDimSizeC1hwncoc0 = 6;

// Cube size is 16
const uint32_t kTheCubeSize = 16;

// Default c0 size equals cube size.
const uint32_t kC0SizeDefault = kTheCubeSize;

// Size equals int8 cube size is 32
const uint32_t kC0SizeInt8 = 32;

// NCHW dim N index
const int32_t kNchwDimIdxN = 0;
// NCHW dim C index
const int32_t kNchwDimIdxC = 1;
// NCHW dim H index
const int32_t kNchwDimIdxH = 2;
// NCHW dim W index
const int32_t kNchwDimIdxW = 3;

const int kDataMemAlignSize = 32;
const int kNum2 = 2;
} // namespace

///
/// Check if a * b overflow.
/// @param a multiplier
/// @param b Multiplicand
/// @return true: overflow
/// false: not overflow
///
static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) {
if (a > 0) {
if (b > 0) {
if (a > (INT64_MAX / b)) {
return true;
}
} else {
if (b < (INT64_MIN / a)) {
return true;
}
}
} else {
if (b > 0) {
if (a < (INT64_MIN / b)) {
return true;
}
} else {
if ((a != 0) && (b < (INT64_MAX / a))) {
return true;
}
}
}
return false;
}

///
/// Calculate element num by dims directly.
/// @param dims dim info
/// @param element_cnt element count
/// @return GRAPH_SUCCESS:success
/// other:failed
///
static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_t &element_cnt) {
element_cnt = 1;
for (int64_t dim : dims) {
if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19013", {"function", "var1", "var2"},
{"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)});
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim);
return GRAPH_FAILED;
}
element_cnt *= dim;
}
return GRAPH_SUCCESS;
}

///
/// Calculate fixed dims element num.
/// @param dims dim info
/// @param fixed_dim_size fixed dim size
/// @param element_cnt element count
/// @return GRAPH_SUCCESS:success
/// other:failed
///
static graphStatus CalcElementCntOfFixedDims(const std::vector<int64_t> &dims, Format format, uint32_t fixed_dim_size,
int64_t &element_cnt) {
if (dims.size() != fixed_dim_size) {
GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", format,
TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size());
}
return CalcElementCntByDims(dims, element_cnt);
}

///
/// Get dim c0 size by type
/// @param data_type data type
/// @return c0 size
///
static uint32_t GetDimC0(DataType &data_type) {
bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) ||
(data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8);
return is_int8_size ? kC0SizeInt8 : kC0SizeDefault;
}

///
/// Calculate nc1hwc0 element num.
/// @param dims dim info
/// @param data_type data type
/// @param element_cnt element count
/// @return GRAPH_SUCCESS:success
/// other:failed
///
static graphStatus CalcElementCntOfNc1hwc0(const std::vector<int64_t> &dims, DataType data_type, int64_t &element_cnt) {
// When nc1hwc0 dims size = 5, no need split dim c
if (dims.size() == kNc1hwc0CalcByDimsSize) {
return CalcElementCntByDims(dims, element_cnt);
} else if (dims.size() != kDimSize4d) {
GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d,
kNc1hwc0CalcByDimsSize);
return GRAPH_FAILED;
}

auto c0 = static_cast<int64_t>(GetDimC0(data_type));
// Nc1hwc0 dims is according to nchw, dim c index is 1.
auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
// Store dims is split c to c1 and c0.
std::vector<int64_t> store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};
return CalcElementCntByDims(store_dims, element_cnt);
}

///
/// Calculate FractalZ element num.
/// @param dims dim info
/// @param data_type data type
/// @param element_cnt element count
/// @return GRAPH_SUCCESS:success
/// other:failed
///
static graphStatus CalcElementCntOfFractalZ(const std::vector<int64_t> &dims, DataType data_type,
int64_t &element_cnt) {
static char *parser_priority = std::getenv("PARSER_PRIORITY");
if (parser_priority != nullptr && string(parser_priority) == "cce") {
if (dims.size() != kDimSize4d) {
GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d);
return GRAPH_FAILED;
}
auto c0 = static_cast<int64_t>(GetDimC0(data_type));
// FractalZ dims is according to nchw, dim c index is 1.
auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));

// Spread NC1HWC0 as a two dimension array, n as column dimension,
// C1HWC0 as row dimension
std::vector<int64_t> r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};

int64_t r_count = 1;
graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count);
if (graph_status != GRAPH_SUCCESS) {
GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH],
dims[kNchwDimIdxW], c0);
return graph_status;
}

// Cube count in n
auto nc_cnt = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize));

// Cube count in vertical direction(C1HWC0)
int64_t vc_cnt = r_count / c0;
// Element count in each cube
int64_t cube_elem_cnt = c0 * kTheCubeSize;

if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) {
GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt);
return GRAPH_FAILED;
}
// Read data times needed by cube
int64_t c_cnt = nc_cnt * vc_cnt;

if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) {
GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt);
return GRAPH_FAILED;
}
// Element count after fractal arrangement
element_cnt = c_cnt * cube_elem_cnt;
return GRAPH_SUCCESS;
} else {
return CalcElementCntByDims(dims, element_cnt);
}
}

///
/// Calculate tensor element num.
/// @param dims dim info
/// @param format tensor format
/// @param data_type data type
/// @param element_cnt element count
/// @return GRAPH_SUCCESS:success
/// other:failed
///
static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format format, DataType data_type,
int64_t &element_cnt) {
const string format_str = TypeUtils::FormatToSerialString(format);
// Check dims
for (size_t i = 0; i < dims.size(); ++i) {
int64_t dim = dims[i];
if (dim < 0) {
GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str());
element_cnt = kElementCntUnknownShape;
return GRAPH_SUCCESS;
} else if (dim == 0) {
GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str());
element_cnt = 0;
return GRAPH_SUCCESS;
}
}

graphStatus graph_status;
switch (format) {
case FORMAT_ND:
case FORMAT_MD:
graph_status = CalcElementCntByDims(dims, element_cnt);
break;
case FORMAT_NCHW:
case FORMAT_HWCN:
case FORMAT_NHWC:
case FORMAT_CHWN:
graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt);
break;
case FORMAT_C1HWNCoC0:
graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt);
break;
case FORMAT_NC1HWC0:
graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt);
break;
case FORMAT_FRACTAL_Z:
graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
break;
case FORMAT_FRACTAL_NZ:
case FORMAT_FRACTAL_ZZ:
case FORMAT_NDHWC:
case FORMAT_NCDHW:
case FORMAT_DHWCN:
case FORMAT_DHWNC:
case FORMAT_FRACTAL_Z_3D:
case FORMAT_FRACTAL_Z_3D_TRANSPOSE:
case FORMAT_NDC1HWC0:
case FORMAT_FRACTAL_Z_C04:
case FORMAT_FRACTAL_ZN_LSTM:
case FORMAT_NC1HWC0_C04:
graph_status = CalcElementCntByDims(dims, element_cnt);
break;
default:
GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str());
graph_status = GRAPH_FAILED;
break;
}

const string type_str = TypeUtils::DataTypeToSerialString(data_type);
if (graph_status == GRAPH_SUCCESS) {
GELOGD(
"CalcTensorElementCnt end, format=%d(%s),"
" data_type=%d(%s), element_cnt=%ld.",
format, format_str.c_str(), data_type, type_str.c_str(), element_cnt);
} else {
GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(),
data_type, type_str.c_str());
}
return graph_status;
}

///
/// Calculate tensor mem size.
/// @param shape tensor shape
/// @param format tensor format
/// @param data_type tensor data type
/// @param mem_size -1 means unknown shape,other means mem size
/// @return GRAPH_SUCCESS:success, other:failed
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape,
Format format,
DataType data_type,
int64_t &mem_size) {
const string format_str = TypeUtils::FormatToSerialString(format);
const string type_str = TypeUtils::DataTypeToSerialString(data_type);
uint32_t type_size = 0;
bool result = TypeUtils::GetDataTypeLength(data_type, type_size);
if (!result) {
GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str());
return GRAPH_FAILED;
}

std::vector<int64_t> dims = shape.GetDims();
int64_t element_cnt = 0;
graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt);
if (status != GRAPH_SUCCESS) {
GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format,
format_str.c_str(), data_type, type_str.c_str());
return status;
}
// Support unknown shape
if (element_cnt < 0) {
mem_size = kMemSizeUnknownShape;
GELOGD(
"element_cnt is unknown. "
"format=%d(%s), data_type=%d(%s), mem_size=%ld",
format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
return GRAPH_SUCCESS;
}
auto type_size_int64 = static_cast<int64_t>(type_size);
if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) {
GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).",
element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str());
return GRAPH_FAILED;
}
mem_size = element_cnt * type_size_int64;

GELOGD(
"CalcTensorMemSize end, "
"format=%d(%s), data_type=%d(%s), mem_size=%ld",
format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp);
if (graph_status != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
// 64-byte alignment, if size is 0, align to 32 bytes
if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) {
GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp);
} else {
size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
}
return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
GeShape output_shape = desc_temp.GetShape();
Format format = desc_temp.GetFormat();
DataType data_type = desc_temp.GetDataType();
int64_t output_mem_size = 0;
graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size);
if (graph_status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!");
return GRAPH_FAILED;
}

if (output_mem_size < 0) {
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]",
output_mem_size, INT64_MAX);
return GRAPH_FAILED;
}

size_temp = output_mem_size;
return GRAPH_SUCCESS;
}
} // namespace ge

+ 0
- 684
metadef/graph/utils/tuning_utils.cc View File

@@ -1,684 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/tuning_utils.h"
#include "../debug/ge_util.h"
#include "../debug/ge_op_types.h"

namespace ge {
const std::string peer_node_name_attr = "_peerNodeName";
const std::string parent_node_name_attr = "_parentNodeName";
const std::string alias_name_attr = "_aliasName";
const std::string parent_node_attr = "parentNode";
const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex";
const std::string tuning_subgraph_prefix = "/aicore_subgraph_";
const std::string non_tuning_subgraph_prefix = "/subgraph_";
const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END};
const std::set<std::string> kExeTypes = {DATA, NETOUTPUT};
NodeNametoNodeNameMap TuningUtils::data_2_netoutput_;
NodetoNodeNameMap TuningUtils::data_node_2_netoutput_;
NodetoNodeMap TuningUtils::data_node_2_netoutput_node_;
NodeSet TuningUtils::netoutput_nodes_;
NodeSet TuningUtils::merged_graph_nodes_;
SubgraphCreateOutNode TuningUtils::create_output_;
std::mutex TuningUtils::mutex_;

std::string TuningUtils::PrintCheckLog() {
std::stringstream ss;
ss << "d2n:{";
for (const auto &pair : data_2_netoutput_) {
ss << "data:" << pair.first << "-"
<< "netoutput:" << pair.second;
ss << " | ";
}
ss << "}";
ss << "netoutputs:{";
for (const auto &node : netoutput_nodes_) {
ss << "netoutput:" << node->GetName();
ss << " | ";
}
ss << "}";
return ss.str();
}

std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) {
if (anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Anchor is nullptr");
return "Null";
}
auto node = anchor->GetOwnerNode();
return node == nullptr ? "Null" : node->GetName();
}

// part 1
graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs,
std::vector<ComputeGraphPtr> non_tuning_subgraphs, bool exe_flag,
const std::string &path, const std::string &user_path) {
int64_t i = 0;
int64_t j = 0;
std::lock_guard<std::mutex> lock(mutex_);
for (auto &subgraph : tuning_subgraphs) {
create_output_.emplace(subgraph, nullptr);
auto help_info = HelpInfo{i, exe_flag, true, path, user_path};
if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i);
return GRAPH_FAILED;
}
i++;
}

for (auto &subgraph : non_tuning_subgraphs) {
create_output_.emplace(subgraph, nullptr);
auto help_info = HelpInfo{j, true, false, path, user_path};
if (MakeExeGraph(subgraph, help_info) != SUCCESS) {
GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j);
return GRAPH_FAILED;
}
j++;
}
create_output_.clear();
return SUCCESS;
}

// +---------------+
// | pld pld |
// | \ / |
// | relu relu |
// | \ / |
// | add |
// | | |
// | end |
// +---------------+
// |
// |
// V
// +---------------+
// | data data |
// | \ / |
// | relu relu |
// | \ / |
// | add |
// | | |
// | netoutput |
// +---------------+
graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) {
GE_CHECK_NOTNULL(exe_graph);
// if not make exe, just dump and return
if (!help_info.exe_flag) {
DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index);
return SUCCESS;
}
// modify sub graph
for (NodePtr &node : exe_graph->GetDirectNode()) {
// 1.handle pld
if (node->GetType() == PLACEHOLDER) {
if (HandlePld(node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
exe_graph->GetName().c_str());
return FAILED;
}
}
// 2.handle end
if (node->GetType() == END) {
if (HandleEnd(node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(),
exe_graph->GetName().c_str());
return FAILED;
}
}
}
graphStatus ret = exe_graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret);
return ret;
}
// dump subgraphs which modified by us
if (help_info.user_path.empty()) {
DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
} else {
GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path);
}
return SUCCESS;
}

void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) {
if (!path.empty()) {
if (is_tuning_graph) {
GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
} else {
GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
}
} else {
path = "./";
if (is_tuning_graph) {
GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt");
} else {
GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt");
}
}
}

graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) {
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA);
GE_CHECK_NOTNULL(data_op_desc);
auto pld_op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(pld_op_desc);
auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data
// data inputdesc & outputdesc set as same
if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) {
GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
return FAILED;
}
if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) {
GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str());
return FAILED;
}
data_node = graph->AddNode(data_op_desc);
GE_CHECK_NOTNULL(data_node);
if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
return FAILED;
}
return SUCCESS;
}

graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) {
auto op_desc = data_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);

auto pld_desc = pld->GetOpDesc();
GE_CHECK_NOTNULL(pld_desc);
// inherit
// a. set `end's input node type` as attr
std::string parent_op_type;
if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) {
GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str());
return FAILED;
}
(void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type);
// b. set `end's input node name` as attr
std::string parent_op_name;
if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) {
GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str());
return FAILED;
}
(void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name);
// c. set `end's input node's out anchor index` as attr
int parent_node_anchor_index;
if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) {
GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str());
return FAILED;
}
(void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index);
GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
data_node->GetName().c_str(), data_node->GetType().c_str());
// d. set `end node name` as attr
std::string peer_end_name;
if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) {
GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str());
return FAILED;
}
(void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name);
GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(),
data_node->GetName().c_str(), data_node->GetType().c_str());
return SUCCESS;
}

graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) {
auto type_pld = node->GetType();
auto type_data = data_node->GetType();
if (type_pld != PLACEHOLDER || type_data != DATA) {
GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(),
type_data.c_str());
return FAILED;
}
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
output_map[i] = static_cast<int>(i);
}

auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(),
data_node->GetName().c_str(), ret);
return FAILED;
}

NodeUtils::UnlinkAll(*node);

ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
return FAILED;
}

GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(),
node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str());
return ret;
}

graphStatus TuningUtils::HandlePld(NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
NodePtr data_node = nullptr;

// 1. create data node
if (CreateDataNode(node, data_node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
// 2. add necessary info to data_node for recovery whole graph
if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
// 3. replace pld node by data node created before
if (ChangePld2Data(node, data_node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
GELOGD("TUU:pld[%s] handle success", node->GetName().c_str());
return SUCCESS;
}

graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) {
GE_CHECK_NOTNULL(node);
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
auto search = create_output_.find(graph);
if (search == create_output_.end()) {
GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(),
graph->GetName().c_str());
return FAILED;
}
if (search->second != nullptr) {
out_node = search->second;
GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str());
return SUCCESS;
}
auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT);
GE_CHECK_NOTNULL(out_op_desc);
out_node = graph->AddNode(out_op_desc);
GE_CHECK_NOTNULL(out_node);
if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed");
return FAILED;
}
create_output_[graph] = out_node;
return SUCCESS;
}

graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) {
GE_CHECK_NOTNULL(end);
GE_CHECK_NOTNULL(out_node);
auto op_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
std::vector<std::string> alias_names = {};
(void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names);
alias_names.push_back(end->GetName());
(void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names);
return SUCCESS;
}

graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
GE_CHECK_NOTNULL(end_node);
GE_CHECK_NOTNULL(out_node);
// get end in node is control node or normal node
AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor())
: Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0));
auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1
if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(),
end_node->GetOwnerComputeGraph()->GetName().c_str());
return FAILED;
}
// add edge between `end in node` and `out_node`
if (src_anchor->IsTypeOf<OutDataAnchor>()) {
std::shared_ptr<InDataAnchor> anchor =
ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size());
GE_CHECK_NOTNULL(anchor);
out_node->in_data_anchors_.push_back(anchor);
if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
end_node->GetOwnerComputeGraph()->GetName().c_str());
return FAILED;
}
auto end_op_desc = end_node->GetOpDesc();
GE_CHECK_NOTNULL(end_op_desc);
auto out_node_op_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(out_node_op_desc);
// end node always has one input
if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str());
return FAILED;
}
} else if (src_anchor->IsTypeOf<OutControlAnchor>()) {
auto anchor = out_node->GetInControlAnchor();
if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(),
end_node->GetOwnerComputeGraph()->GetName().c_str());
return FAILED;
}
} else {
GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(),
end_node->GetOwnerComputeGraph()->GetName().c_str());
return FAILED;
}

return SUCCESS;
}

graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) {
GE_CHECK_NOTNULL(end_node);
GE_CHECK_NOTNULL(out_node);
auto type_end = end_node->GetType();
auto type_out = out_node->GetType();
if (type_end != END || type_out != NETOUTPUT) {
GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(),
type_end.c_str(), type_out.c_str());
return FAILED;
}
// link all `end nodes's in node` to this out_node
if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) {
GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str());
return FAILED;
}
// remove `end node`
NodeUtils::UnlinkAll(*end_node);
auto graph = end_node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) {
GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str());
return FAILED;
}
return SUCCESS;
}

graphStatus TuningUtils::HandleEnd(NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
NodePtr out_node = nullptr;

// 1. create net_output node , add only one NetOutput node to one subgraph
if (CreateNetOutput(node, out_node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
// 2. add necessary info to out_node for recovery whole graph
if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
// 3. replace all end nodes by one output node created before
if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
GELOGD("TUU:end[%s] handle success", node->GetName().c_str());
return SUCCESS;
}

// part 2
graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) {
// 1. get all subgraph object
std::vector<ComputeGraphPtr> graphs;
// options format like {index:"subgraph_path"}
for (const auto &pair : options) {
ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first));
if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) {
GELOGE(FAILED, "TUU:load graph from file failed");
}
graphs.push_back(compute_graph);
}
// 2. merge graph
ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune");
GE_CHECK_NOTNULL(merged_graph);
if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) {
GELOGE(FAILED, "TUU:MergeGraph failed");
return FAILED;
}
// 3. set parent graph
for (const auto &node : merged_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str());
return FAILED;
}
}
graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph);
return SUCCESS;
}

// +----------------------------------+
// | const const |
// | \ / |
// | netoutput(end,end) |
// +----------------------------------+
// +
// +----------------------------------+
// | data(pld) data(pld) |
// | \ / |
// | relu relu |
// | \ / |
// | \ / |
// | add |
// | | |
// | netoutput(end) |
// +----------------------------------+
// +
// +----------------------------------+
// | data(pld) |
// | / |
// | netoutput |
// +----------------------------------+
// |
// |
// V
// +----------------------------------+
// | const const |
// | \ / |
// | relu relu |
// | \ / |
// | \ / |
// | add |
// | | |
// | netoutput |
// +----------------------------------+
graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs,
ComputeGraphPtr &output_merged_compute_graph) {
GE_CHECK_NOTNULL(output_merged_compute_graph);
// 1. handle all subgraphs
for (auto &subgraph : subgraphs) {
Status ret_status = MergeSubGraph(subgraph);
if (ret_status != SUCCESS) {
GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str());
return ret_status;
}
}

for (const auto &node : merged_graph_nodes_) {
(void)output_merged_compute_graph->AddNode(node);
GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str());
}

// 2. remove data and output node added by us
if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) {
GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str());
return FAILED;
}
graphStatus ret = output_merged_compute_graph->TopologicalSorting();
if (ret != SUCCESS) {
GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret);
return ret;
}
GELOGD("TUU:Print-%s", PrintCheckLog().c_str());
GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str());
return SUCCESS;
}

graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) {
for (auto &node : subgraph->GetDirectNode()) {
if (kPartitionOpTypes.count(node->GetType()) > 0) {
GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type");
return FAILED;
}
// handle data converted from pld node
if (node->GetType() == DATA) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
std::string peer_out_name;
bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty());
if (has_valid_str) {
std::lock_guard<std::mutex> lock(mutex_);
data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name);
data_node_2_netoutput_.emplace(node, peer_out_name);
continue;
}
}
// handle netoutput converted from end node
if (node->GetType() == NETOUTPUT) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
std::vector<string> out_alias_name;
bool has_valid_str =
(AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty());
if (has_valid_str) {
std::lock_guard<std::mutex> lock(mutex_);
netoutput_nodes_.insert(node);
}
}
{
std::lock_guard<std::mutex> lock(mutex_);
merged_graph_nodes_.emplace(node);
}
GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str());
}
GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str());
return SUCCESS;
}

graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
// 1. traverse
for (auto &pair : data_node_2_netoutput_) {
auto data_node = pair.first;
GE_CHECK_NOTNULL(data_node);
auto netoutput_name = pair.second;
auto netoutput_node = graph->FindNode(netoutput_name);
GE_CHECK_NOTNULL(netoutput_node);
data_node_2_netoutput_node_.emplace(data_node, netoutput_node);
// 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor`
AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor())
: Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutDataAnchor(0));
AnchorPtr net_output_in_anchor = nullptr;
AnchorPtr src_out_anchor = nullptr;
if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed",
netoutput_node->GetName().c_str(), data_node->GetName().c_str());
return FAILED;
}
// 3. relink
if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(),
data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
GE_CHECK_NOTNULL(data_out_anchor);
for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) {
if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(),
GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s",
GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(),
GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(),
data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str());
return FAILED;
}
}
}
// 4. remove out nodes added by us
for (auto &node : netoutput_nodes_) {
NodeUtils::UnlinkAll(*node);
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str());
return FAILED;
}
GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str());
}
return SUCCESS;
}

graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor,
AnchorPtr &src_out_anchor) {
// 1. get `data parent node name`, i.e. `netoutput input node name`
std::string netoutput_input_name;
auto op_desc = data_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) {
GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str());
return FAILED;
}
// 2. find index
int parent_node_anchor_index;
if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) {
GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str());
return FAILED;
}
// 3.find in data or ctrl anchor by 1&2 step
for (auto &in_anchor : out_node->GetAllInAnchors()) {
GE_CHECK_NOTNULL(in_anchor);
for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl
GE_CHECK_NOTNULL(src_anchor);
auto src_node = src_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) {
dest_in_anchor = in_anchor;
src_out_anchor = src_anchor;
GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s",
out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(),
parent_node_anchor_index, data_node->GetName().c_str());
break;
}
}
}
GE_CHECK_NOTNULL(dest_in_anchor);
GE_CHECK_NOTNULL(src_out_anchor);
return SUCCESS;
}

} // namespace ge

+ 0
- 448
metadef/graph/utils/type_utils.cc View File

@@ -1,448 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/utils/type_utils.h"
#include "debug/ge_util.h"

using domi::domiTensorFormat_t;

namespace ge {
static const std::map<Format, std::string> kFormatToStringMap = {
{FORMAT_NCHW, "NCHW"},
{FORMAT_NHWC, "NHWC"},
{FORMAT_ND, "ND"},
{FORMAT_NC1HWC0, "NC1HWC0"},
{FORMAT_FRACTAL_Z, "FRACTAL_Z"},
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"},
{FORMAT_NHWC1C0, "NHWC1C0"},
{FORMAT_FSR_NCHW, "FSR_NCHW"},
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"},
{FORMAT_C1HWNC0, "C1HWNC0"},
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"},
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
{FORMAT_CHWN, "CHWN"},
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"},
{FORMAT_BN_WEIGHT, "BN_WEIGHT"},
{FORMAT_FILTER_HWCK, "FILTER_HWCK"},
{FORMAT_HWCN, "HWCN"},
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
{FORMAT_MD, "MD"},
{FORMAT_NDHWC, "NDHWC"},
{FORMAT_NCDHW, "NCDHW"},
{FORMAT_DHWCN, "DHWCN"},
{FORMAT_DHWNC, "DHWNC"},
{FORMAT_NDC1HWC0, "NDC1HWC0"},
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
{FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
{FORMAT_C1HWNCoC0, "C1HWNCoC0"},
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"},
{FORMAT_CN, "CN"},
{FORMAT_NC, "NC"},
{FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
{FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"},
{FORMAT_RESERVED, "FORMAT_RESERVED"},
{FORMAT_ALL, "ALL"}};

static const std::map<domiTensorFormat_t, Format> kDomiFormatToGeFormat = {
{domi::DOMI_TENSOR_NCHW, FORMAT_NCHW},
{domi::DOMI_TENSOR_NHWC, FORMAT_NHWC},
{domi::DOMI_TENSOR_ND, FORMAT_ND},
{domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0},
{domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z},
{domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD},
{domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0},
{domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW},
{domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV},
{domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT},
{domi::DOMI_TENSOR_CHWN, FORMAT_CHWN},
{domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK},
{domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC},
{domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW},
{domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN},
{domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC},
{domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}};

static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0",
"FRACTAL_Z",
"NC1C0HWPAD",
"NHWC1C0",
"FRACTAL_DECONV",
"C1HWNC0",
"FRACTAL_DECONV_TRANSPOSE",
"FRACTAL_DECONV_SP_STRIDE_TRANS",
"NC1HWC0_C04",
"FRACTAL_Z_C04",
"FRACTAL_DECONV_SP_STRIDE8_TRANS",
"NC1KHKWHWC0",
"C1HWNCoC0",
"FRACTAL_ZZ",
"FRACTAL_NZ",
"NDC1HWC0",
"FORMAT_FRACTAL_Z_3D",
"FORMAT_FRACTAL_Z_3D_TRANSPOSE",
"FORMAT_FRACTAL_ZN_LSTM",
"FORMAT_FRACTAL_Z_G"};

static const std::map<std::string, Format> kDataFormatMap = {
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}};

static const std::map<std::string, Format> kStringToFormatMap = {
{"NCHW", FORMAT_NCHW},
{"NHWC", FORMAT_NHWC},
{"ND", FORMAT_ND},
{"NC1HWC0", FORMAT_NC1HWC0},
{"FRACTAL_Z", FORMAT_FRACTAL_Z},
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD},
{"NHWC1C0", FORMAT_NHWC1C0},
{"FSR_NCHW", FORMAT_FSR_NCHW},
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV},
{"C1HWNC0", FORMAT_C1HWNC0},
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE},
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04},
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04},
{"CHWN", FORMAT_CHWN},
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0},
{"BN_WEIGHT", FORMAT_BN_WEIGHT},
{"FILTER_HWCK", FORMAT_FILTER_HWCK},
{"HWCN", FORMAT_HWCN},
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS},
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE},
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS},
{"MD", FORMAT_MD},
{"C1HWNCoC0", FORMAT_C1HWNCoC0},
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ},
{"NDHWC", FORMAT_NDHWC},
{"NCDHW", FORMAT_NCDHW},
{"DHWCN", FORMAT_DHWCN},
{"DHWNC", FORMAT_DHWNC},
{"NDC1HWC0", FORMAT_NDC1HWC0},
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D},
{"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE},
{"CN", FORMAT_CN},
{"NC", FORMAT_NC},
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM},
{"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G},
{"FORMAT_RESERVED", FORMAT_RESERVED},
{"ALL", FORMAT_ALL},
{"NULL", FORMAT_NULL}};

static const std::map<DataType, std::string> kDataTypeToStringMap = {
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.
{DT_FLOAT, "DT_FLOAT"}, // float type
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type
{DT_INT8, "DT_INT8"}, // int8 type
{DT_INT16, "DT_INT16"}, // int16 type
{DT_UINT16, "DT_UINT16"}, // uint16 type
{DT_UINT8, "DT_UINT8"}, // uint8 type
{DT_INT32, "DT_INT32"}, // uint32 type
{DT_INT64, "DT_INT64"}, // int64 type
{DT_UINT32, "DT_UINT32"}, // unsigned int32
{DT_UINT64, "DT_UINT64"}, // unsigned int64
{DT_BOOL, "DT_BOOL"}, // bool type
{DT_DOUBLE, "DT_DOUBLE"}, // double type
{DT_DUAL, "DT_DUAL"}, // dual output type
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type
{DT_QINT8, "DT_QINT8"}, // qint8 type
{DT_QINT16, "DT_QINT16"}, // qint16 type
{DT_QINT32, "DT_QINT32"}, // qint32 type
{DT_QUINT8, "DT_QUINT8"}, // quint8 type
{DT_QUINT16, "DT_QUINT16"}, // quint16 type
{DT_RESOURCE, "DT_RESOURCE"}, // resource type
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type
{DT_STRING, "DT_STRING"}, // string type
};

static const std::map<std::string, DataType> kStringTodataTypeMap = {
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set.
{"DT_FLOAT", DT_FLOAT}, // float type
{
"DT_FLOAT16",
DT_FLOAT16,
}, // fp16 type
{"DT_INT8", DT_INT8}, // int8 type
{"DT_INT16", DT_INT16}, // int16 type
{"DT_UINT16", DT_UINT16}, // uint16 type
{"DT_UINT8", DT_UINT8}, // uint8 type
{"DT_INT32", DT_INT32}, // uint32 type
{"DT_INT64", DT_INT64}, // int64 type
{"DT_UINT32", DT_UINT32}, // unsigned int32
{"DT_UINT64", DT_UINT64}, // unsigned int64
{"DT_BOOL", DT_BOOL}, // bool type
{"DT_DOUBLE", DT_DOUBLE}, // double type
{"DT_DUAL", DT_DUAL}, // dual output type
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type
{"DT_QINT8", DT_QINT8}, // qint8 type
{"DT_QINT16", DT_QINT16}, // qint16 type
{"DT_QINT32", DT_QINT32}, // qint32 type
{"DT_QUINT8", DT_QUINT8}, // quint8 type
{"DT_QUINT16", DT_QUINT16}, // quint16 type
{"DT_RESOURCE", DT_RESOURCE}, // resource type
{"DT_STRING_REF", DT_STRING_REF}, // string ref type
{"DT_STRING", DT_STRING}, // string type
};

static const std::map<ge::DataType, uint32_t> kDataTypeToLength = {
{DT_BOOL, sizeof(bool)},
{DT_INT64, sizeof(int64_t)},
{DT_UINT64, sizeof(int64_t)},
{DT_FLOAT, sizeof(float)},
{DT_INT32, sizeof(int32_t)},
{DT_UINT32, sizeof(int32_t)},
{DT_INT8, sizeof(char)},
{DT_UINT8, sizeof(char)},
{DT_INT16, sizeof(int16_t)},
{DT_UINT16, sizeof(int16_t)},
{DT_FLOAT16, sizeof(int16_t)},
{DT_DOUBLE, sizeof(double)},
{DT_DUAL, sizeof(float) + sizeof(int8_t)},
{DT_DUAL_SUB_INT8, sizeof(int8_t)},
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)},
{DT_COMPLEX64, sizeof(int64_t)},
{DT_COMPLEX128, sizeof(int64_t) * 2},
{DT_QINT8, sizeof(int8_t)},
{DT_QINT16, sizeof(int16_t)},
{DT_QINT32, sizeof(int32_t)},
{DT_QUINT8, sizeof(uint8_t)},
{DT_QUINT16, sizeof(uint16_t)},
{DT_STRING_REF, sizeof(uint64_t) * 2},
{DT_STRING, sizeof(uint64_t)},
{DT_RESOURCE, sizeof(uint64_t)},
};

static const std::map<domi::FrameworkType, std::string> kFmkTypeToString = {
{domi::CAFFE, "caffe"}, {domi::MINDSPORE, "mindspore"}, {domi::TENSORFLOW, "tensorflow"},
{domi::ANDROID_NN, "android_nn"}, {domi::ONNX, "onnx"}, {domi::FRAMEWORK_RESERVED, "framework_reserved"},
};

static const std::map<domi::ImplyType, std::string> kImplyTypeToString = {
{domi::ImplyType::BUILDIN, "buildin"}, {domi::ImplyType::TVM, "tvm"}, {domi::ImplyType::CUSTOM, "custom"},
{domi::ImplyType::AI_CPU, "ai_cpu"}, {domi::ImplyType::CCE, "cce"}, {domi::ImplyType::GELOCAL, "gelocal"},
{domi::ImplyType::HCCL, "hccl"}, {domi::ImplyType::INVALID, "invalid"}};

std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) {
auto it = kImplyTypeToString.find(imply_type);
if (it != kImplyTypeToString.end()) {
return it->second;
} else {
GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type);
return "UNDEFINED";
}
}

bool TypeUtils::IsDataTypeValid(DataType dt) {
uint32_t num = static_cast<uint32_t>(dt);
GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid");
return true;
}

std::string TypeUtils::DataTypeToSerialString(DataType data_type) {
auto it = kDataTypeToStringMap.find(data_type);
if (it != kDataTypeToStringMap.end()) {
return it->second;
} else {
GELOGE(GRAPH_FAILED, "DataTypeToSerialString: datatype not support %u", data_type);
return "UNDEFINED";
}
}

DataType TypeUtils::SerialStringToDataType(const std::string &str) {
auto it = kStringTodataTypeMap.find(str);
if (it != kStringTodataTypeMap.end()) {
return it->second;
} else {
GELOGE(GRAPH_FAILED, "SerialStringToDataType: datatype not support %s", str.c_str());
return DT_UNDEFINED;
}
}

bool TypeUtils::IsFormatValid(Format format) {
uint32_t num = static_cast<uint32_t>(format);
GE_CHK_BOOL_EXEC((num <= FORMAT_RESERVED), return false, "The Format is invalid");
return true;
}

bool TypeUtils::IsInternalFormat(Format format) {
std::string serial_format = FormatToSerialString(format);
auto iter = kInternalFormat.find(serial_format);
bool result = (iter == kInternalFormat.end()) ? false : true;
return result;
}

std::string TypeUtils::FormatToSerialString(Format format) {
auto it = kFormatToStringMap.find(format);
if (it != kFormatToStringMap.end()) {
return it->second;
} else {
GELOGE(GRAPH_FAILED, "Format not support %u", format);
return "RESERVED";
}
}
Format TypeUtils::SerialStringToFormat(const std::string &str) {
auto it = kStringToFormatMap.find(str);
if (it != kStringToFormatMap.end()) {
return it->second;
} else {
GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str());
return FORMAT_RESERVED;
}
}

Format TypeUtils::DataFormatToFormat(const std::string &str) {
auto it = kDataFormatMap.find(str);
if (it != kDataFormatMap.end()) {
return it->second;
} else {
GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str());
return FORMAT_RESERVED;
}
}

Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) {
auto it = kDomiFormatToGeFormat.find(domi_format);
if (it != kDomiFormatToGeFormat.end()) {
return it->second;
}
GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format);
return FORMAT_RESERVED;
}

std::string TypeUtils::FmkTypeToSerialString(domi::FrameworkType fmk_type) {
auto it = kFmkTypeToString.find(fmk_type);
if (it != kFmkTypeToString.end()) {
return it->second;
} else {
GELOGW("Framework type not support %d.", fmk_type);
return "";
}
}

static inline void CopyDataFromBuffer(vector<uint8_t> &data, const Buffer &buffer) {
data.clear();
if (buffer.GetData() != nullptr && buffer.GetSize() != 0) {
data.assign(buffer.GetData(), buffer.GetData() + buffer.GetSize());
}
}

graphStatus Usr2DefQuantizeFactor(const UsrQuantizeFactor &usr, QuantizeFactor &def) {
def.scale_mode = uint32_t(usr.scale_mode);
def.set_scale_value(usr.scale_value.data(), usr.scale_value.size());
def.scale_offset = usr.scale_offset;
def.set_offset_data_value(usr.offset_data_value.data(), usr.offset_data_value.size());
def.offset_data_offset = usr.offset_data_offset;
def.set_offset_weight_value(usr.offset_weight_value.data(), usr.offset_weight_value.size());
def.offset_weight_offset = usr.offset_weight_offset;
def.set_offset_pad_value(usr.offset_pad_value.data(), usr.offset_pad_value.size());
def.offset_pad_offset = usr.offset_pad_offset;
return GRAPH_SUCCESS;
}
graphStatus Def2UsrQuantizeFactor(const QuantizeFactor &def, UsrQuantizeFactor &usr) {
usr.scale_mode = UsrQuantizeScaleMode(def.scale_mode);
CopyDataFromBuffer(usr.scale_value, def.scale_value);
usr.scale_offset = def.scale_offset;
CopyDataFromBuffer(usr.offset_data_value, def.offset_data_value);
usr.offset_data_offset = def.offset_data_offset;
CopyDataFromBuffer(usr.offset_weight_value, def.offset_weight_value);
usr.offset_weight_offset = def.offset_weight_offset;
CopyDataFromBuffer(usr.offset_pad_value, def.offset_pad_value);
usr.offset_pad_offset = def.offset_pad_offset;
return GRAPH_SUCCESS;
}
graphStatus Usr2DefUsrQuantizeCalcFactor(const UsrQuantizeCalcFactor &usr, QuantizeCalcFactor &def) {
def.set_offsetw(usr.offsetw.data(), usr.offsetw.size());
def.offsetw_offset = usr.offsetw_offset;
def.set_offsetd(usr.offsetd.data(), usr.offsetd.size());
def.offsetd_offset = usr.offsetd_offset;
def.set_scalereq(usr.scalereq.data(), usr.scalereq.size());
def.scaledreq_offset = usr.scaledreq_offset;
def.set_offsetdnext(usr.offsetdnext.data(), usr.offsetdnext.size());
def.offsetdnext_offset = usr.offsetdnext_offset;
return GRAPH_SUCCESS;
}
graphStatus Def2UsrQuantizeCalcFactor(const QuantizeCalcFactor &def, UsrQuantizeCalcFactor &usr) {
CopyDataFromBuffer(usr.offsetw, def.offsetw);
usr.offsetw_offset = def.offsetw_offset;
CopyDataFromBuffer(usr.offsetd, def.offsetd);
usr.offsetd_offset = def.offsetd_offset;
CopyDataFromBuffer(usr.scalereq, def.scalereq);
usr.scaledreq_offset = def.scaledreq_offset;
CopyDataFromBuffer(usr.offsetdnext, def.offsetdnext);
usr.offsetdnext_offset = def.offsetdnext_offset;
return GRAPH_SUCCESS;
}
graphStatus TypeUtils::Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def) {
def.quantize_algo = uint32_t(usr.quantize_algo);
def.scale_type = uint32_t(usr.scale_type);
GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.quantize_param, def.quantize_param),
"Usr2DefQuantizeFactor quantize_param failed");
GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.dequantize_param, def.dequantize_param),
"Usr2DefQuantizeFactor dequantize_param failed");
GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.requantize_param, def.requantize_param),
"Usr2DefQuantizeFactor requantize_param failed");
GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefUsrQuantizeCalcFactor(usr.quantizecalc_param, def.quantizecalc_param),
"Usr2DefQuantizeFactor quantizecalc_param failed");
return GRAPH_SUCCESS;
}
graphStatus TypeUtils::Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr) {
usr.quantize_algo = UsrQuantizeAlgorithm(def.quantize_algo);
usr.scale_type = UsrQuantizeScaleType(def.scale_type);
GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.quantize_param, usr.quantize_param),
"Def2UsrQuantizeFactor quantize_param failed");
GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.dequantize_param, usr.dequantize_param),
"Def2UsrQuantizeFactor dequantize_param failed");
GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.requantize_param, usr.requantize_param),
"Def2UsrQuantizeFactor requantize_param failed");
GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeCalcFactor(def.quantizecalc_param, usr.quantizecalc_param),
"Def2UsrQuantizeCalcFactor quantizecalc_param failed");
return GRAPH_SUCCESS;
}
bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) {
auto it = kDataTypeToLength.find(data_type);
if (it != kDataTypeToLength.end()) {
length = it->second;
return true;
} else {
GELOGE(GRAPH_FAILED, "data_type not support %d", data_type);
return false;
}
}
bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) {
// Not overflow
if (a == 0) {
return false;
}
if ((ULLONG_MAX / a) >= b) {
return false;
}
return true;
}
} // namespace ge

+ 0
- 75
metadef/inc/external/graph/attr_value.h View File

@@ -1,75 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_
#define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "./ge_error_codes.h"

using std::make_shared;
using std::map;
using std::pair;
using std::string;
using std::to_string;
using std::unique_ptr;
using std::vector;

namespace ge {
class AttrValueImpl;
/*lint -e148*/
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue {
public:
using INT = int64_t;
using FLOAT = float;
using STR = std::string;

AttrValue();
~AttrValue() = default;

// GetValue, not list type
template <typename T, typename DT>
graphStatus GetValue(DT &val) const {
T valGet;
auto status = GetValue(valGet);
if (status != GRAPH_SUCCESS) {
return status;
}
val = DT(valGet);
return GRAPH_SUCCESS;
}

template <typename T, typename DT>
static T CreateFrom(DT &&val) {
return val;
}

std::shared_ptr<AttrValueImpl> impl;

private:
#define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const;
VALUE_SET_GET_DEC(AttrValue::STR)
VALUE_SET_GET_DEC(AttrValue::INT)
VALUE_SET_GET_DEC(AttrValue::FLOAT)
#undef VALUE_SET_GET_DEC
};
/*lint +e148*/
} // namespace ge
#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_

+ 0
- 38
metadef/inc/external/graph/ge_error_codes.h View File

@@ -1,38 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_
#define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_

namespace ge {
#ifdef HOST_VISIBILITY
#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_HOST_VISIBILITY
#endif
#ifdef DEV_VISIBILITY
#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_DEV_VISIBILITY
#endif

using graphStatus = uint32_t;
const graphStatus GRAPH_FAILED = 0xFFFFFFFF;
const graphStatus GRAPH_SUCCESS = 0;
const graphStatus GRAPH_PARAM_INVALID = 50331649;
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_

+ 0
- 81
metadef/inc/external/graph/graph.h View File

@@ -1,81 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_
#define INC_EXTERNAL_GRAPH_GRAPH_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "./operator.h"

namespace ge {
class GraphImpl;

using GraphImplPtr = std::shared_ptr<GraphImpl>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph {
friend class GraphUtils;

public:
explicit Graph(const std::string &name);

Graph() = default;

~Graph() = default;

Graph &SetInputs(const std::vector<Operator> &inputs);

Graph &SetOutputs(const std::vector<Operator> &outputs);

Graph &SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs);

Graph &SetOutputs(const std::vector<std::pair<ge::Operator, std::string>> &outputs);

Graph &SetTargets(const std::vector<Operator> &targets);

bool IsValid() const;

graphStatus AddOp(const ge::Operator &op);

graphStatus FindOpByName(const string &name, ge::Operator &op) const;

graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const;

graphStatus GetAllOpName(std::vector<string> &op_name) const;

graphStatus SaveToFile(const string &file_name) const;

graphStatus LoadFromFile(const string &file_name);

const std::string &GetName() const;

///
/// Set is need train iteration.
/// If set true, it means this graph need to be run iteration some
/// times(according variant "npu_runconfig/iterations_per_loop").
/// @param need_iteration need_iteration:whether to set iteration or not
///
void SetNeedIteration(bool need_iteration);

private:
GraphImplPtr impl_{nullptr};
};
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_GRAPH_H_

+ 0
- 76
metadef/inc/external/graph/inference_context.h View File

@@ -1,76 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_
#define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_

#include <memory>
#include <string>
#include <vector>

#include "./tensor.h"
#include "./types.h"

namespace ge {
class InferenceContext;
using InferenceContextPtr = std::shared_ptr<InferenceContext>;

class ShapeAndTypeImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType {
public:
ShapeAndType();
~ShapeAndType() = default;

ShapeAndType(const Shape &shape, DataType dataType);

void SetShape(const Shape &shape);

void SetType(DataType dataType);

Shape GetShape() const;

DataType GetDataType() const;

private:
std::shared_ptr<ShapeAndTypeImpl> shape_and_type_impl_;
};

class InferenceContextImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext {
public:
~InferenceContext() = default;
InferenceContext(const InferenceContext &context) = delete;
InferenceContext(const InferenceContext &&context) = delete;
InferenceContext &operator=(const InferenceContext &context) = delete;
InferenceContext &operator=(const InferenceContext &&context) = delete;

void SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types);
const std::vector<std::vector<ShapeAndType>> &GetInputHandleShapesAndTypes() const;
const std::vector<std::vector<ShapeAndType>> &GetOutputHandleShapesAndTypes() const;
void SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types);
void SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types);

void SetMarks(const std::vector<std::string> &marks);
const std::vector<std::string> &GetMarks() const;

static std::unique_ptr<InferenceContext> Create();

private:
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl);
std::shared_ptr<InferenceContextImpl> inference_context_impl_;
};
} // namespace ge
#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_

+ 0
- 289
metadef/inc/external/graph/operator.h View File

@@ -1,289 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_
#define INC_EXTERNAL_GRAPH_OPERATOR_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "./ge_error_codes.h"
#include "./inference_context.h"
#include "./tensor.h"

#ifndef USER_GE_LOGI
#define USER_GE_LOGI(...)
#endif // USER_GE_LOGI

#ifndef USER_GE_LOGW
#define USER_GE_LOGW(...)
#endif // USER_GE_LOGW

#ifndef USER_GE_LOGE
#define USER_GE_LOGE(...)
#endif // USER_GE_LOGE

#define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt")
#define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt")

namespace ge {
class Operator;
class OperatorImpl;
class NodeUtils;
class NamedAttrs;
class Graph;
class AttrValue;
class Node;

using SubgraphBuilder = std::function<Graph()>;
using OperatorImplPtr = std::shared_ptr<OperatorImpl>;
using OperatorPtr = std::shared_ptr<Operator>;

class OpIO;
using OutHandler = std::shared_ptr<OpIO>;
using InHandler = std::shared_ptr<OpIO>;

using std::function;
using std::shared_ptr;
using std::string;

/*lint -e148*/
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
public:
friend class OperatorImpl;
friend class GraphBuilderImpl;
friend class NodeUtils;

using OpInt = int64_t;
using OpFloat = float;
using OpString = string;
using OpBool = bool;
using OpTensor = Tensor;
using OpType = ge::DataType;
using OpNamedAttrs = ge::NamedAttrs;
using OpListInt = std::vector<int64_t>;
using OpListFloat = std::vector<float>;
using OpListString = std::vector<string>;
using OpListBool = std::vector<bool>;
using OpListTensor = std::vector<Tensor>;
using OpBytes = std::vector<uint8_t>;
using OpListListInt = std::vector<std::vector<int64_t>>;
using OpListType = std::vector<ge::DataType>;
using OpListNamedAttrs = std::vector<ge::NamedAttrs>;

Operator() {}

explicit Operator(const string &type);

Operator(const string &name, const string &type); // lint !e148

virtual ~Operator() = default;

bool IsEmpty() const;

string GetName() const;

string GetOpType() const;

// Only has one output index = 0
Operator &SetInput(const string &dst_name, const Operator &src_oprt);

Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148

Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index);

Operator &AddControlInput(const Operator &src_oprt);

graphStatus GetInputConstData(const string &dst_name, Tensor &data) const;

TensorDesc GetInputDesc(const string &name) const;

TensorDesc GetInputDesc(uint32_t index) const;

int GetDynamicOutputNum(const string &name) const;

int GetDynamicInputNum(const string &name) const;

graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const;

graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc);

TensorDesc GetOutputDesc(const string &name) const;

TensorDesc GetOutputDesc(uint32_t index) const;

graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148

TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const;

graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148

TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const;

graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148

graphStatus InferShapeAndType(); // lint !e148

void SetInferenceContext(const InferenceContextPtr &inference_context);
InferenceContextPtr GetInferenceContext() const;

graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148

size_t GetInputsSize() const;

size_t GetOutputsSize() const;

const std::map<std::string, std::string> GetAllAttrNamesAndTypes() const;

Operator &SetAttr(const string &name, int64_t attr_value);
Operator &SetAttr(const string &name, int32_t attr_value);
Operator &SetAttr(const string &name, uint32_t attr_value);
graphStatus GetAttr(const string &name, int64_t &attr_value) const;
graphStatus GetAttr(const string &name, int32_t &attr_value) const;
graphStatus GetAttr(const string &name, uint32_t &attr_value) const;
Operator &SetAttr(const string &name, const std::vector<int64_t> &attr_value);
Operator &SetAttr(const string &name, const std::vector<int32_t> &attr_value);
Operator &SetAttr(const string &name, const std::vector<uint32_t> &attr_value);
Operator &SetAttr(const string &name, std::initializer_list<int64_t> &&attr_value);
graphStatus GetAttr(const string &name, std::vector<int64_t> &attr_value) const;
graphStatus GetAttr(const string &name, std::vector<int32_t> &attr_value) const;
graphStatus GetAttr(const string &name, std::vector<uint32_t> &attr_value) const;

Operator &SetAttr(const string &name, float attr_value);
graphStatus GetAttr(const string &name, float &attr_value) const;
Operator &SetAttr(const string &name, const std::vector<float> &attr_value);
graphStatus GetAttr(const string &name, std::vector<float> &attr_value) const;
Operator &SetAttr(const string &name, AttrValue &&attr_value);
graphStatus GetAttr(const string &name, AttrValue &attr_value) const;

Operator &SetAttr(const string &name, const string &attr_value);
graphStatus GetAttr(const string &name, string &attr_value) const;
Operator &SetAttr(const string &name, const std::vector<string> &attr_value);
graphStatus GetAttr(const string &name, std::vector<string> &attr_value) const;

Operator &SetAttr(const string &name, bool attr_value);
graphStatus GetAttr(const string &name, bool &attr_value) const;
Operator &SetAttr(const string &name, const std::vector<bool> &attr_value);
graphStatus GetAttr(const string &name, std::vector<bool> &attr_value) const;

Operator &SetAttr(const string &name, const Tensor &attr_value);
graphStatus GetAttr(const string &name, Tensor &attr_value) const;
Operator &SetAttr(const string &name, const std::vector<Tensor> &attr_value);
graphStatus GetAttr(const string &name, std::vector<Tensor> &attr_value) const;

// Bytes type
Operator &SetAttr(const string &name, const OpBytes &attr_value);
// Bytes type
graphStatus GetAttr(const string &name, OpBytes &attr_value) const;

Operator &SetAttr(const string &name, const std::vector<std::vector<int64_t>> &attr_value);
graphStatus GetAttr(const string &name, std::vector<std::vector<int64_t>> &attr_value) const;

Operator &SetAttr(const string &name, const std::vector<ge::DataType> &attr_value);
graphStatus GetAttr(const string &name, std::vector<ge::DataType> &attr_value) const;

Operator &SetAttr(const string &name, const ge::DataType &attr_value);
graphStatus GetAttr(const string &name, ge::DataType &attr_value) const;

// func type
Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value);
graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const;
Operator &SetAttr(const string &name, const std::vector<ge::NamedAttrs> &attr_value);
graphStatus GetAttr(const string &name, std::vector<ge::NamedAttrs> &attr_value) const;

void BreakConnect() const;

size_t GetSubgraphNamesCount() const;
std::vector<std::string> GetSubgraphNames() const;
SubgraphBuilder GetSubgraphBuilder(const string &name) const;
Graph GetSubgraph(const string &name) const;
SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const;
Graph GetDynamicSubgraph(const string &name, uint32_t index) const;

protected:
void AttrRegister(const string &name, float attr_value);
void AttrRegister(const string &name, const std::vector<float> &attr_value);
void AttrRegister(const string &name, int64_t attr_value);
void AttrRegister(const string &name, const std::vector<int64_t> &attr_value);
void AttrRegister(const string &name, const string &attr_value);
void AttrRegister(const string &name, const std::vector<string> &attr_value);
void AttrRegister(const string &name, bool attr_value);
void AttrRegister(const string &name, const std::vector<bool> &attr_value);
void AttrRegister(const string &name, const Tensor &attr_value);
void AttrRegister(const string &name, const std::vector<Tensor> &attr_value);
void AttrRegister(const string &name, const OpBytes &attr_value);
void AttrRegister(const string &name, const std::vector<std::vector<int64_t>> &attr_value);
void AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value);
void AttrRegister(const string &name, const ge::DataType &attr_value);
void AttrRegister(const string &name, const ge::NamedAttrs &attr_value);
void AttrRegister(const string &name, const std::vector<ge::NamedAttrs> &attr_value);

explicit Operator(OperatorImplPtr &&op_impl);

void InputRegister(const string &name);

void OptionalInputRegister(const string &name);

void InferFuncRegister(const std::function<graphStatus(Operator &)> &func);

void VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func);

void InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func);

void OutputRegister(const string &name);

void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true);

void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index);

void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true);

void RequiredAttrRegister(const string &name);

graphStatus VerifyAll(); // lint !e148

// Only has one output index = 0
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt);

Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt,
const string &name); // lint !e148

void SubgraphRegister(const string &ir_name, bool dynamic);
void SubgraphCountRegister(const string &ir_name, uint32_t count);
void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder);

private:
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148

OutHandler GetOutput(const string &name) const;

OutHandler GetOutput(uint32_t index) const;

OperatorImplPtr GetOperatorImplPtr() const;

OperatorImplPtr operator_impl_{nullptr};

graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const;

std::shared_ptr<const Node> GetNode() const;
};
/*lint +e148*/
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_

+ 0
- 68
metadef/inc/external/graph/operator_factory.h View File

@@ -1,68 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_
#define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "./operator.h"
#include "./ge_error_codes.h"

namespace ge {
using OpCreator = std::function<Operator(const std::string &)>;
using InferShapeFunc = std::function<graphStatus(Operator &)>;
using InferFormatFunc = std::function<graphStatus(Operator &)>;
using VerifyFunc = std::function<graphStatus(Operator &)>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory {
public:
static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type);

static graphStatus GetOpsTypeList(std::vector<std::string> &all_ops);

static bool IsExistOp(const string &operator_type);
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister {
public:
OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator);
~OperatorCreatorRegister() = default;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister {
public:
InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func);
~InferShapeFuncRegister() = default;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister {
public:
InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func);
~InferFormatFuncRegister() = default;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister {
public:
VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func);
~VerifyFuncRegister() = default;
};
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_

+ 0
- 376
metadef/inc/external/graph/operator_reg.h View File

@@ -1,376 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "graph/operator.h"
#include "graph/operator_factory.h"
#include "graph/tensor.h"
#include "graph/types.h"
#include "graph/graph.h"

namespace ge {
using std::function;
using std::string;
using std::vector;

class OpReg {
public:
OpReg &N() { return *this; }

OpReg &ATTR() { return *this; }

OpReg &REQUIRED_ATTR() { return *this; }

OpReg &INPUT() { return *this; }

OpReg &OPTIONAL_INPUT() { return *this; }

OpReg &OUTPUT() { return *this; }

OpReg &GRAPH() { return *this; }

OpReg &DYNAMIC_GRAPH() { return *this; }

OpReg &INFER_SHAPE_AND_TYPE() { return *this; }
};

#define REG_OP(x) \
namespace op { \
class x : public Operator { \
typedef x _THIS_TYPE; \
\
public: \
explicit x(const string &name) : Operator(name, #x) { __##x(); } \
x() : Operator(#x) { __##x(); } \
\
private: \
void __##x() { \
OpReg()

#define ATTR(x, Type, ...) \
N(); \
__attr_##x(); \
} \
\
public: \
static const string name_attr_##x() { return #x; } \
Op##Type get_attr_##x() const { \
Op##Type ret = __VA_ARGS__; \
if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
return ret; \
} \
return ret; \
} \
_THIS_TYPE &set_attr_##x(const Op##Type &v) { \
Operator::SetAttr(#x, v); \
return *this; \
} \
_THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
\
private: \
void __attr_##x() { \
Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \
string attr_name(#x); \
(void)OpReg()

#define REQUIRED_ATTR(x, Type) \
N(); \
__required_attr_##x(); \
} \
\
public: \
static const string name_attr_##x() { return #x; } \
Op##Type get_attr_##x() const { \
Op##Type ret; \
if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
return ret; \
} \
return ret; \
} \
_THIS_TYPE &set_attr_##x(const Op##Type &v) { \
Operator::SetAttr(#x, v); \
return *this; \
} \
_THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
\
private: \
void __required_attr_##x() { \
Operator::RequiredAttrRegister(#x); \
string attr_name(#x); \
(void)OpReg()

#define INPUT(x, t) \
N(); \
__input_##x(); \
} \
\
public: \
static const string name_in_##x() { return #x; } \
_THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
Operator::SetInput(#x, v, srcName); \
return *this; \
} \
_THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \
Operator::SetInput(#x, v, index); \
return *this; \
} \
_THIS_TYPE &set_input_##x(Operator &v) { \
Operator::SetInput(#x, v); \
return *this; \
} \
TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \
graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
return Operator::UpdateInputDesc(#x, tensorDesc); \
} \
\
private: \
void __input_##x() { \
Operator::InputRegister(#x); \
(void)OpReg()

#define OPTIONAL_INPUT(x, t) \
N(); \
__optional_input_##x(); \
} \
\
public: \
static const string name_in_##x() { return #x; } \
_THIS_TYPE &set_input_##x(Operator &v) { \
Operator::SetInput(#x, v); \
return *this; \
} \
_THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
Operator::SetInput(#x, v, srcName); \
return *this; \
} \
_THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \
Operator::SetInput(#x, v, index); \
return *this; \
} \
TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \
graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
return Operator::UpdateInputDesc(#x, tensorDesc); \
} \
\
private: \
void __optional_input_##x() { \
Operator::OptionalInputRegister(#x); \
(void)OpReg()

#define OUTPUT(x, t) \
N(); \
__out_##x(); \
} \
\
public: \
static const string name_out_##x() { return #x; } \
TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \
graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \
return Operator::UpdateOutputDesc(#x, tensorDesc); \
} \
\
private: \
void __out_##x() { \
Operator::OutputRegister(#x); \
(void)OpReg()

#define DYNAMIC_INPUT(x, t) \
N(); \
__dy_input_##x(); \
} \
\
public: \
_THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \
Operator::DynamicInputRegister(#x, num, isPushBack); \
return *this; \
} \
_THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \
Operator::DynamicInputRegisterByIndex(#x, num, index); \
return *this; \
} \
TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \
graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \
} \
_THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \
Operator::SetInput(#x, dstIndex, v); \
return *this; \
} \
_THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \
Operator::SetInput(#x, dstIndex, v, srcName); \
return *this; \
} \
\
private: \
void __dy_input_##x() { \
Operator::DynamicInputRegister(#x, 0, true); \
(void)OpReg()

#define DYNAMIC_OUTPUT(x, t) \
N(); \
__dy_output_##x(); \
} \
\
public: \
_THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \
Operator::DynamicOutputRegister(#x, num, isPushBack); \
return *this; \
} \
TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \
graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \
} \
\
private: \
void __dy_output_##x() { \
Operator::DynamicOutputRegister(#x, 0, true); \
(void)OpReg()

#define GRAPH(x) \
N(); \
__graph_##x(); \
} \
\
public: \
static const string name_graph_##x() { return #x; } \
SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \
_THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \
Operator::SetSubgraphBuilder(#x, 0, v); \
return *this; \
} \
Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \
\
private: \
void __graph_##x() { \
Operator::SubgraphRegister(#x, false); \
Operator::SubgraphCountRegister(#x, 1); \
(void)OpReg()

#define DYNAMIC_GRAPH(x) \
N(); \
__graph_##x(); \
} \
\
public: \
static const string name_graph_##x() { return #x; } \
_THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \
Operator::SubgraphCountRegister(#x, num); \
return *this; \
} \
SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \
return Operator::GetDynamicSubgraphBuilder(#x, index); \
} \
Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \
_THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \
Operator::SetSubgraphBuilder(#x, index, v); \
return *this; \
} \
\
private: \
void __graph_##x() { \
Operator::SubgraphRegister(#x, true); \
(void)OpReg()

#define PASTE(g_register, y) g_register##y
#define __OP_END_IMPL__(x, y) \
N(); \
} \
static_assert( \
std::is_same<x, _THIS_TYPE>::value, \
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
} \
; \
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \
}
#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__)

// Specialized shape inferencer macro

#define IMPLEMT_INFERFUNC(op_name, func_name) \
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)

#define IMPLEMT_COMMON_INFERFUNC(func_name) \
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op)

#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)

// Specialized verifier macro

#define IMPLEMT_VERIFIER(op_name, func_name) \
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op)

#define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); }

#define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); }

#define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); }

#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x)

#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x)
// Infer format func register
#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \
static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x)

// Shape inferencer & verifier register macro

#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)

#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__)

#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)

// Infer format func reg
#define INFER_FORMAT_FUNC_REG(op_name, x) \
__INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__)

// Common shape inferencer

#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \
[](Operator op) -> graphStatus { \
auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \
auto x_type = op.GetInputDesc(in_name).GetDataType(); \
TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
op_output_desc.SetShape(ge::Shape(x_shape)); \
op_output_desc.SetOriginShape(ge::Shape(x_shape)); \
op_output_desc.SetDataType(x_type); \
return op.UpdateOutputDesc(out_name, op_output_desc); \
}

graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
const function<vector<int64_t>()> &get_in2_shape,
const function<void(const vector<int64_t> &y_shape)> &set_out_shape);

#define BROADCAST_INFER(in1_name, in2_name, out_name) \
[](Operator op) -> graphStatus { \
return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \
[&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \
[&](const vector<int64_t> &y_shape) { \
TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
op_output_desc.SetShape(ge::Shape(y_shape)); \
(void)op.UpdateOutputDesc(out_name, op_output_desc); \
}); \
}
} // namespace ge
#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_

+ 0
- 131
metadef/inc/external/graph/tensor.h View File

@@ -1,131 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_
#define INC_EXTERNAL_GRAPH_TENSOR_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>
#include <utility>

#include "./ge_error_codes.h"
#include "./types.h"

namespace ge {
class ShapeImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape {
public:
Shape();
~Shape() = default;
explicit Shape(const std::vector<int64_t> &dims);

size_t GetDimNum() const;
// If the idx is invalid, return 0
int64_t GetDim(size_t idx) const;
graphStatus SetDim(size_t idx, int64_t value);
std::vector<int64_t> GetDims() const;
int64_t GetShapeSize() const;

private:
std::shared_ptr<ShapeImpl> impl_;
};

class TensorDescImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc {
public:
TensorDesc();
~TensorDesc() = default;
explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
// Copy
TensorDesc(const TensorDesc &desc);
// Move
TensorDesc(TensorDesc &&desc);
// Copy
TensorDesc &operator=(const TensorDesc &desc);
// Move
TensorDesc &operator=(TensorDesc &&desc);

void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
Shape GetShape() const;
void SetShape(const Shape &shape);
// set shape with -2, it stand for unknown shape
graphStatus SetUnknownDimNumShape();
// for unknown shape
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;

Format GetFormat() const;
void SetFormat(Format format);

Shape GetOriginShape() const;
void SetOriginShape(const Shape &originShape);

Format GetOriginFormat() const;
void SetOriginFormat(Format originFormat);

DataType GetDataType() const;
void SetDataType(DataType dt);

std::string GetName() const;
void SetName(const std::string &name);

// Attr acess
void SetSize(int64_t size);
int64_t GetSize() const;

int64_t GetRealDimCnt() const;
void SetRealDimCnt(const int64_t realDimCnt);

private:
std::shared_ptr<TensorDescImpl> impl;
};

class TensorImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor {
public:
Tensor();
~Tensor() = default;
explicit Tensor(const TensorDesc &tensorDesc);
Tensor(const TensorDesc &tensorDesc, const std::vector<uint8_t> &data);
Tensor(const TensorDesc &tensorDesc, const uint8_t *data, size_t size);
Tensor(TensorDesc &&tensorDesc, std::vector<uint8_t> &&data);

TensorDesc GetTensorDesc() const;
graphStatus SetTensorDesc(const TensorDesc &tensorDesc);

const uint8_t *GetData() const;
uint8_t *GetData();
size_t GetSize() const;

graphStatus SetData(std::vector<uint8_t> &&data);
graphStatus SetData(const std::vector<uint8_t> &data);
graphStatus SetData(const uint8_t *data, size_t size);
graphStatus SetData(const std::string &data);
graphStatus SetData(const std::vector<std::string> &data);
graphStatus IsValid();

Tensor Clone() const;

private:
std::shared_ptr<TensorImpl> impl;
friend class TensorAdapter;
};
} // namespace ge
/*lint +e148*/

#endif // INC_EXTERNAL_GRAPH_TENSOR_H_

+ 0
- 240
metadef/inc/external/graph/types.h View File

@@ -1,240 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_TYPES_H_
#define INC_EXTERNAL_GRAPH_TYPES_H_

#include <atomic>
#include <memory>
#include <vector>

namespace ge {
static const int64_t UNKNOWN_DIM = -1;
static const int64_t UNKNOWN_DIM_NUM = -2;
static const std::vector<int64_t> UNKNOWN_SHAPE = {-1};
static const std::vector<int64_t> UNKNOWN_RANK = {-2};

#ifdef HOST_VISIBILITY
#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_HOST_VISIBILITY
#endif
#ifdef DEV_VISIBILITY
#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_DEV_VISIBILITY
#endif

enum DataType {
DT_FLOAT = 0, // float type
DT_FLOAT16 = 1, // fp16 type
DT_INT8 = 2, // int8 type
DT_INT16 = 6, // int16 type
DT_UINT16 = 7, // uint16 type
DT_UINT8 = 4, // uint8 type
DT_INT32 = 3, //
DT_INT64 = 9, // int64 type
DT_UINT32 = 8, // unsigned int32
DT_UINT64 = 10, // unsigned int64
DT_BOOL = 12, // bool type
DT_DOUBLE = 11, // double type
DT_STRING = 13, // string type
DT_DUAL_SUB_INT8 = 14, // dual output int8 type
DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type
DT_COMPLEX64 = 16, // complex64 type
DT_COMPLEX128 = 17, // complex128 type
DT_QINT8 = 18, // qint8 type
DT_QINT16 = 19, // qint16 type
DT_QINT32 = 20, // qint32 type
DT_QUINT8 = 21, // quint8 type
DT_QUINT16 = 22, // quint16 type
DT_RESOURCE = 23, // resource type
DT_STRING_REF = 24, // string ref type
DT_DUAL = 25, // dual output type
DT_UNDEFINED // Used to indicate a DataType field has not been set.
};

inline int GetSizeByDataType(DataType data_type) {
static int data_type_size[DT_UNDEFINED] = {
4, // DT_FLOAT = 0, float type
2, // DT_FLOAT16 = 1, fp16 type
1, // DT_INT8 = 2, int8 type
4, // DT_INT32 = 3,
1, // DT_UINT8 = 4, uint8 type
-1,
2, // DT_INT16 = 6, int16 type
2, // DT_UINT16 = 7, uint16 type
4, // DT_UINT32 = 8, unsigned int32
8, // DT_INT64 = 9, int64 type
8, // DT_UINT64 = 10, unsigned int64
8, // DT_DOUBLE = 11, double type
1, // DT_BOOL = 12, bool type
-1, // DT_STRING = 13, string type
1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type
1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type
8, // DT_COMPLEX64 = 16, complex64 type
16, // DT_COMPLEX128 = 17, complex128 type
1, // DT_QINT8 = 18, qint8 type
2, // DT_QINT16 = 19, qint16 type
4, // DT_QINT32 = 20, qint32 type
1, // DT_QUINT8 = 21, quint8 type
2, // DT_QUINT16 = 22, quint16 type
-1, // DT_RESOURCE = 23, resource type
-1, // DT_STRING_REF = 24, string ref type
5, // DT_DUAL = 25, dual output type (float + int8)
// DT_UNDEFINED Used to indicate a DataType field has not been set.
};
if (data_type >= DT_UNDEFINED) {
return -1;
}
return data_type_size[data_type];
}

enum Format {
FORMAT_NCHW = 0, // NCHW
FORMAT_NHWC, // NHWC
FORMAT_ND, // Nd Tensor
FORMAT_NC1HWC0, // NC1HWC0
FORMAT_FRACTAL_Z, // FRACTAL_Z
FORMAT_NC1C0HWPAD,
FORMAT_NHWC1C0,
FORMAT_FSR_NCHW,
FORMAT_FRACTAL_DECONV,
FORMAT_C1HWNC0,
FORMAT_FRACTAL_DECONV_TRANSPOSE,
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS,
FORMAT_NC1HWC0_C04, // NC1HWC0, C0 =4
FORMAT_FRACTAL_Z_C04, // FRACZ, C0 =4
FORMAT_CHWN,
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS,
FORMAT_HWCN,
FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format
FORMAT_BN_WEIGHT,
FORMAT_FILTER_HWCK, // filter input tensor format
FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20,
FORMAT_HASHTABLE_LOOKUP_KEYS,
FORMAT_HASHTABLE_LOOKUP_VALUE,
FORMAT_HASHTABLE_LOOKUP_OUTPUT,
FORMAT_HASHTABLE_LOOKUP_HITS = 24,
FORMAT_C1HWNCoC0,
FORMAT_MD,
FORMAT_NDHWC,
FORMAT_FRACTAL_ZZ,
FORMAT_FRACTAL_NZ,
FORMAT_NCDHW,
FORMAT_DHWCN, // 3D filter input tensor format
FORMAT_NDC1HWC0,
FORMAT_FRACTAL_Z_3D,
FORMAT_CN,
FORMAT_NC,
FORMAT_DHWNC,
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format
FORMAT_FRACTAL_ZN_LSTM,
FORMAT_FRACTAL_Z_G,
FORMAT_RESERVED,
FORMAT_ALL,
FORMAT_NULL
};

// for unknown shape op type
enum UnknowShapeOpType {
DEPEND_IN_SHAPE = 1, // op out shape get by input shape
DEPEND_CONST_VALUE = 2, // op out shape get by const op value
DEPEND_SHAPE_RANGE = 3, // op out shape get by range
DEPEND_COMPUTE = 4 // op out shape get by totally computing
};

struct TensorDescInfo {
Format format_ = FORMAT_RESERVED; // tbe op register support format
DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype
};

enum DeviceType {
NPU = 0,
CPU = 1,
};

class TensorTypeImpl;
struct TensorType {
explicit TensorType(DataType dt);

TensorType(const std::initializer_list<DataType> &types);

static TensorType ALL() {
return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16,
DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16,
DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8};
}

static TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; }

static TensorType OrdinaryType() {
return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16,
DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8};
}

static TensorType BasicType() {
return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16,
DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8,
DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8};
}

static TensorType NumberType() {
return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64,
DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8};
}

static TensorType RealNumberType() {
return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64,
DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8};
}

static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64}; }

static TensorType IntegerDataType() {
return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8};
}

static TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; }

static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; }

static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; }

static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; }

static TensorType UnaryDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; }

static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; }

std::shared_ptr<TensorTypeImpl> tensor_type_impl_;
};
} // namespace ge

namespace domi {
enum class ImplyType : unsigned int {
BUILDIN = 0, // Built in operator, normally executed by OME
TVM, // Compile to TVM bin file for execution
CUSTOM, // User defined calculation logic, executed by CPU
AI_CPU, // AICPU
CCE, // Cce
GELOCAL, // GE local, do node need execute by device
HCCL, // Hccl
INVALID = 0xFFFFFFFF,
};
} // namespace domi

#endif // INC_EXTERNAL_GRAPH_TYPES_H_

+ 0
- 163
metadef/inc/external/register/register.h View File

@@ -1,163 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_
#define INC_EXTERNAL_REGISTER_REGISTER_H_

#include <functional>
#include <initializer_list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <unordered_map>
#include <vector>

#include "graph/operator.h"
#include "register/register_error_codes.h"
#include "register/register_fmk_types.h"
#include "register/register_types.h"

using std::make_shared;
using std::map;
using std::pair;
using std::string;
using std::to_string;
using std::unique_ptr;
using std::vector;

/*lint -e148*/
namespace ge {
class Operator;
class TensorDesc;
class Tensor;
class TBEPluginManager;
} // namespace ge

namespace google {
namespace protobuf {
class Message;
}
} // namespace google

namespace domi {
const int64_t kMaxNameLength = 1048576; // 1M

enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 };
struct DynamicInputOutputInfo {
DynamicType type; // input/output
const char *port_name;
int64_t port_name_len;
const char *attr_name;
int64_t attr_name_len;
DynamicInputOutputInfo()
: type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {}
DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name,
int64_t attr_name_len)
: type(type),
port_name(port_name),
port_name_len(port_name_len),
attr_name(attr_name),
attr_name_len(attr_name_len) {}
};
Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op);
Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op,
const vector<DynamicInputOutputInfo> &dynamic_name_attr_value);
Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value,
int in_pos = -1, int out_pos = -1);
Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input,
const std::function<int(int netoutput_index)> &output);
Status AutoMappingSubgraphIndex(const ge::Graph &graph,
const std::function<Status(int data_index, int &parent_input_index)> &input,
const std::function<Status(int netoutput_index, int &parent_output_index)> &output);
using google::protobuf::Message;
class OpRegistrationDataImpl;

using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;
using FusionParseParamFunc =
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;
using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>;

class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
public:
OpRegistrationData(const std::string &om_optype);

~OpRegistrationData();

OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type);

OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list);

OpRegistrationData &OriginOpType(const std::string &ori_optype);

OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);

OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn);

OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn);

OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn);

OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn);

OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);

OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);

OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);

OpRegistrationData &InputReorderVector(const vector<int> &input_order);

domi::ImplyType GetImplyType() const;
std::string GetOmOptype() const;
std::set<std::string> GetOriginOpTypeSet() const;
domi::FrameworkType GetFrameworkType() const;
ParseParamFunc GetParseParamFn() const;
ParseParamByOpFunc GetParseParamByOperatorFn() const;
FusionParseParamFunc GetFusionParseParamFn() const;
FusionParseParamByOpFunc GetFusionParseParamByOpFn() const;
ParseSubgraphFunc GetParseSubgraphPostFn() const;

private:
std::shared_ptr<OpRegistrationDataImpl> impl_;
friend class OpRegistry;
friend class OpRegistrationTbe;
friend class ge::TBEPluginManager;
};

class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
public:
OpReceiver(OpRegistrationData &reg_data);
~OpReceiver() {}
};

#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name)
#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \
static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name)
} // namespace domi

namespace ge {
using OpRegistrationData = domi::OpRegistrationData;
using OpReceiver = domi::OpReceiver;
} // namespace ge
/*lint +e148*/
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_

+ 0
- 39
metadef/inc/external/register/register_error_codes.h View File

@@ -1,39 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_
#define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_

#define SYSID_FWK 3 // Subsystem ID
#define MODID_COMMON 0 // Common module ID

#define DECLARE_ERRORNO(sysid, modid, name, value) \
const domi::Status name = \
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value));

#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value)

namespace domi {
using Status = uint32_t;

// General error code
DECLARE_ERRORNO(0, 0, SUCCESS, 0);
DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF);
DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649
DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201);
} // namespace domi

#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_

+ 0
- 37
metadef/inc/external/register/register_fmk_types.h View File

@@ -1,37 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_
#define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_

#include <string>

namespace domi {
///
/// @ingroup domi_omg
/// @brief AI framework types
///
enum FrameworkType {
CAFFE = 0,
MINDSPORE = 1,
TENSORFLOW = 3,
ANDROID_NN,
ONNX,
FRAMEWORK_RESERVED,
};
} // namespace domi

#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_

+ 0
- 59
metadef/inc/external/register/register_types.h View File

@@ -1,59 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_
#define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_

namespace domi {
#ifdef HOST_VISIBILITY
#define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default")))
#else
#define FMK_FUNC_HOST_VISIBILITY
#endif
#ifdef DEV_VISIBILITY
#define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default")))
#else
#define FMK_FUNC_DEV_VISIBILITY
#endif

/// CCE defined constant

///
/// @ingroup domi
/// @brief original tensor type
///
typedef enum tagDomiTensorFormat {
DOMI_TENSOR_NCHW = 0, // < NCHW
DOMI_TENSOR_NHWC, // < NHWC
DOMI_TENSOR_ND, // < Nd Tensor
DOMI_TENSOR_NC1HWC0, // < NC1HWC0
DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z
DOMI_TENSOR_NC1C0HWPAD,
DOMI_TENSOR_NHWC1C0,
DOMI_TENSOR_FSR_NCHW,
DOMI_TENSOR_FRACTAL_DECONV,
DOMI_TENSOR_BN_WEIGHT,
DOMI_TENSOR_CHWN, // Android NN Depth CONV
DOMI_TENSOR_FILTER_HWCK, // filter input tensor format
DOMI_TENSOR_NDHWC,
DOMI_TENSOR_NCDHW,
DOMI_TENSOR_DHWCN, // 3D filter input tensor format
DOMI_TENSOR_DHWNC,
DOMI_TENSOR_RESERVED
} domiTensorFormat_t;
} // namespace domi

#endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_

+ 0
- 334
metadef/inc/external/register/scope/scope_fusion_pass_register.h View File

@@ -1,334 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_

#include <memory>
#include <string>
#include <vector>
#include <map>
#include <unordered_map>
#include "ge/ge_api_error_codes.h"
#include "register/register_error_codes.h"
#include "register/register_types.h"
#include "graph/operator.h"

#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \
do { \
if (!(cond)) { \
if ((fusion_rlt) != nullptr) { \
(fusion_rlt)->SetType(ge::kScopeInvalidType); \
} \
return; \
} \
} while (0)

namespace domi {
class TensorFlowModelParser;
} // namespace domi
namespace ge {
const int32_t kFusionDisableIndex = 99999;
const char *const kScopeToMultiNodes = "ScopeToMultiNodes";
const char *const kScopeInvalidType = "ScopeInvalidType";
const char *const kInputFromFusionScope = "InputFromFusionScope";
const char *const kOutputToFusionScope = "OutputToFusionScope";
class ScopePattern;
using ScopeFusionPatterns = std::vector<std::vector<ScopePattern *>>;

class ScopePassManager;

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope {
public:
Scope();
Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr);
~Scope();

const std::string &Name() const;
const std::string &SubType() const;
const std::unordered_map<std::string, ge::OperatorPtr> &AllNodesMap() const;
Scope *GetSubScope(const std::string &scope_name) const;
const std::string LastName() const;
const std::vector<Scope *> &GetAllSubScopes() const;
const Scope *GetFatherScope() const;

private:
class ScopeImpl;
std::unique_ptr<ScopeImpl> impl_;
friend class ScopeBasePass;
friend class ScopeTree;
friend class NodeOpTypeFeature;
friend class NodeAttrFeature;
friend class ScopeFeature;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult {
public:
FusionScopesResult();
Status Init();
~FusionScopesResult();
void SetName(const std::string &name);
void SetType(const std::string &type);
void SetDescription(const std::string &description);
const std::string &Name() const;
const std::vector<ge::OperatorPtr> &Nodes() const;
void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);

class InnerNodeInfo {
public:
explicit InnerNodeInfo(const std::string &fusion_node_name);
InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type);
InnerNodeInfo(InnerNodeInfo &&other) noexcept;
InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept;
InnerNodeInfo(const InnerNodeInfo &) = delete;
InnerNodeInfo &operator=(const InnerNodeInfo &) = delete;
~InnerNodeInfo();
InnerNodeInfo &SetName(const std::string &name);
InnerNodeInfo &SetType(const std::string &type);
InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx);
InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx);
ge::graphStatus BuildInnerNode();
ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format);
ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format);
ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format);
ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format);
ge::Operator *MutableOperator();

std::string GetName() const;
std::string GetType() const;
std::vector<std::pair<std::string, int32_t>> GetInputs() const;
std::vector<std::pair<std::string, int32_t>> GetOutputs() const;

private:
class InnerNodeInfoImpl;
std::unique_ptr<InnerNodeInfoImpl> impl_;
};

InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type);
InnerNodeInfo *MutableRecentInnerNode();
InnerNodeInfo *MutableInnerNode(uint32_t index);
ge::graphStatus CheckInnerNodesInfo();

private:
class FusionScopesResultImpl;
std::unique_ptr<FusionScopesResultImpl> impl_;
friend class ScopeGraph;
friend class ScopeBasePass;
friend class TensorFlowModelParser;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree {
public:
ScopeTree();
Status Init();
ScopeTree(const ScopeTree &scopetree) = delete;
ScopeTree &operator=(const ScopeTree &scopetree) = delete;
~ScopeTree();

const std::vector<Scope *> &GetAllScopes() const;

private:
class ScopeTreeImpl;
std::unique_ptr<ScopeTreeImpl> impl_;
friend class ScopeGraph;
friend class ScopeBasePass;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph {
public:
ScopeGraph();
Status Init();
ScopeGraph(const ScopeGraph &scope_graph) = delete;
ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete;
~ScopeGraph();

const ScopeTree *GetScopeTree() const;
const std::unordered_map<std::string, ge::OperatorPtr> &GetNodesMap() const;

private:
class ScopeGraphImpl;
std::unique_ptr<ScopeGraphImpl> impl_;
friend class ScopePassManager;
friend class ScopeBasePass;
friend class TensorFlowModelParser;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue {
public:
ScopeAttrValue();
ScopeAttrValue(ScopeAttrValue const &attr_value);
ScopeAttrValue &operator=(ScopeAttrValue const &attr_value);
~ScopeAttrValue();

void SetIntValue(int64_t value);
void SetFloatValue(float value);
void SetStringValue(std::string value);
void SetBoolValue(bool value);

private:
class ScopeAttrValueImpl;
std::unique_ptr<ScopeAttrValueImpl> impl_;
friend class NodeAttrFeature;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature {
public:
virtual bool Match(const Scope *scope) = 0;
virtual ~ScopeBaseFeature(){};
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature {
public:
NodeOpTypeFeature(std::string nodeType, int num, int step = 0);
NodeOpTypeFeature(NodeOpTypeFeature const &feature);
NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature);
~NodeOpTypeFeature();
bool Match(const Scope *scope) override;

private:
class NodeOpTypeFeatureImpl;
std::unique_ptr<NodeOpTypeFeatureImpl> impl_;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature {
public:
NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value);
NodeAttrFeature(NodeAttrFeature const &feature);
NodeAttrFeature &operator=(NodeAttrFeature const &feature);
~NodeAttrFeature();
bool Match(const Scope *scope) override;

private:
class NodeAttrFeatureImpl;
std::unique_ptr<NodeAttrFeatureImpl> impl_;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature {
public:
ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "",
int step = 0);
ScopeFeature(ScopeFeature const &feature);
ScopeFeature &operator=(ScopeFeature const &feature);
~ScopeFeature();
bool Match(const Scope *scope) override;

private:
class ScopeFeatureImpl;
std::unique_ptr<ScopeFeatureImpl> impl_;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern {
public:
ScopePattern();
~ScopePattern();

ScopePattern &SetSubType(const std::string &sub_type);
ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature);
ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature);
ScopePattern &AddScopeFeature(ScopeFeature feature);

private:
class ScopePatternImpl;
std::unique_ptr<ScopePatternImpl> impl_;
friend class ScopeBasePass;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult {
public:
ScopesResult();
ScopesResult(ScopesResult const &result);
ScopesResult &operator=(ScopesResult const &result);
~ScopesResult();

void SetScopes(std::vector<Scope *> &scopes);
void SetNodes(std::vector<ge::OperatorPtr> &nodes);

private:
class ScopesResultImpl;
std::unique_ptr<ScopesResultImpl> impl_;
friend class ScopeBasePass;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass {
public:
ScopeBasePass();
virtual ~ScopeBasePass();

protected:
// Subclasses implement respective fusion strategies and build the Patterns
virtual std::vector<ScopeFusionPatterns> DefinePatterns() = 0;
// Define the name of the scope pass
virtual std::string PassName() = 0;
// Subclasses implement respective multi-scope or operator fusion methods across scopes
virtual Status LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
std::vector<ScopesResult> &results) = 0;
// Subclasses implement their own results and set the input and output of the final fusion operator
virtual void GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) = 0;

private:
class ScopeBasePassImpl;
std::unique_ptr<ScopeBasePassImpl> impl_;
friend class ge::ScopePassManager;
friend class ScopeBasePassImpl;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry {
public:
using CreateFn = ScopeBasePass *(*)();
~ScopeFusionPassRegistry();

static ScopeFusionPassRegistry &GetInstance() {
static ScopeFusionPassRegistry instance;
return instance;
}

void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general);

private:
ScopeFusionPassRegistry();
class ScopeFusionPassRegistryImpl;
/*lint -e148*/
std::unique_ptr<ScopeFusionPassRegistryImpl> impl_;
friend class TensorFlowModelParser;
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil {
public:
static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value);
static void FreeScopePatterns(ScopeFusionPatterns &patterns);
static void FreeOneBatchPattern(std::vector<ScopePattern *> &one_batch_pattern);
};

class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar {
public:
ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general);
~ScopeFusionPassRegistrar() {}
};

#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \
REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general)

#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \
REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general)

#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \
static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \
::ge::ScopeFusionPassRegistrar( \
pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general)
} // namespace ge

#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_

+ 0
- 284
metadef/inc/graph/anchor.h View File

@@ -1,284 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_ANCHOR_H_
#define INC_GRAPH_ANCHOR_H_

#include <memory>
#include <string>
#include <vector>
#include "graph/ge_error_codes.h"
#include "graph/range_vistor.h"
#include "graph/types.h"

namespace ge {
enum AnchorStatus {
ANCHOR_SUSPEND = 0, // dat null
ANCHOR_CONST = 1,
ANCHOR_DATA = 2, // Effective
ANCHOR_RESERVED = 3
};
using std::string;
using std::vector;

class Node;

using NodePtr = std::shared_ptr<Node>;

class Edge;

using EdgePtr = std::shared_ptr<Edge>;

class Anchor;

using AnchorPtr = std::shared_ptr<Anchor>;

class DataAnchor;

using DataAnchorPtr = std::shared_ptr<DataAnchor>;

class InDataAnchor;

using InDataAnchorPtr = std::shared_ptr<InDataAnchor>;

class OutDataAnchor;

using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>;

class ControlAnchor;

using ControlAnchorPtr = std::shared_ptr<ControlAnchor>;

class InControlAnchor;

using InControlAnchorPtr = std::shared_ptr<InControlAnchor>;

class OutControlAnchor;

using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>;

using ConstAnchor = const Anchor;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this<Anchor> {
friend class AnchorUtils;

public:
using TYPE = const char *;
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<ConstAnchor>>;

Anchor(const NodePtr &ownerNode, int idx);

virtual ~Anchor() = default;

protected:
// Whether the two anchor is equal
virtual bool Equal(AnchorPtr anchor) const = 0;
virtual bool IsTypeOf(TYPE type) const;

public:
// Get all peer anchors connected to current anchor
Vistor<AnchorPtr> GetPeerAnchors() const;
// Get peer anchor size
size_t GetPeerAnchorsSize() const;
// Get first peer anchor
AnchorPtr GetFirstPeerAnchor() const;

// Get the anchor belong to which node
NodePtr GetOwnerNode() const;

// Remove all links with the anchor
void UnlinkAll() noexcept;

// Remove link with the given anchor
graphStatus Unlink(const AnchorPtr &peer);

// Replace peer with new peers
graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer);

// Judge if the anchor is linked with the given anchor
bool IsLinkedWith(const AnchorPtr &peer);

// Get anchor index of the node
int GetIdx() const;

// set anchor index of the node
void SetIdx(int index);

protected:
// All peer anchors connected to current anchor
vector<std::weak_ptr<Anchor>> peer_anchors_;
// The owner node of anchor
std::weak_ptr<Node> owner_node_;
// The index of current anchor
int idx_;
template <class T>
static Anchor::TYPE TypeOf() {
static_assert(std::is_base_of<Anchor, T>::value, "T must be a Anchor!");
return __PRETTY_FUNCTION__;
}

public:
template <class T>
static std::shared_ptr<T> DynamicAnchorCast(AnchorPtr anchorPtr) {
static_assert(std::is_base_of<Anchor, T>::value, "T must be a Anchor!");
if (anchorPtr == nullptr || !anchorPtr->IsTypeOf<T>()) {
return nullptr;
}
return std::static_pointer_cast<T>(anchorPtr);
}

template <typename T>
bool IsTypeOf() {
return IsTypeOf(TypeOf<T>());
}
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor {
friend class AnchorUtils;

public:
explicit DataAnchor(const NodePtr &ownerNode, int idx);

virtual ~DataAnchor() = default;

protected:
bool IsTypeOf(TYPE type) const override;

private:
Format format_{FORMAT_ND};
AnchorStatus status_{ANCHOR_SUSPEND};
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor {
friend class OutDataAnchor;

friend class OutControlAnchor;

public:
explicit InDataAnchor(const NodePtr &ownerNode, int idx);

virtual ~InDataAnchor() = default;

// Get source out data anchor
OutDataAnchorPtr GetPeerOutAnchor() const;

// Build connection from OutDataAnchor to InDataAnchor
graphStatus LinkFrom(const OutDataAnchorPtr &src);

protected:
bool Equal(AnchorPtr anchor) const override;
bool IsTypeOf(TYPE type) const override;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor {
friend class InDataAnchor;

friend class AnchorUtils;

public:
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<ConstAnchor>>;

explicit OutDataAnchor(const NodePtr &ownerNode, int idx);

virtual ~OutDataAnchor() = default;
// Get dst in data anchor(one or more)
Vistor<InDataAnchorPtr> GetPeerInDataAnchors() const;
uint32_t GetPeerInDataNodesSize() const;

// Get dst in control anchor(one or more)
Vistor<InControlAnchorPtr> GetPeerInControlAnchors() const;

// Build connection from OutDataAnchor to InDataAnchor
graphStatus LinkTo(const InDataAnchorPtr &dest);

// Build connection from OutDataAnchor to InControlAnchor
graphStatus LinkTo(const InControlAnchorPtr &dest);

protected:
bool Equal(AnchorPtr anchor) const override;
bool IsTypeOf(TYPE type) const override;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor {
public:
explicit ControlAnchor(const NodePtr &ownerNode);

explicit ControlAnchor(const NodePtr &ownerNode, int idx);

virtual ~ControlAnchor() = default;

protected:
bool IsTypeOf(TYPE type) const override;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor {
friend class OutControlAnchor;

friend class OutDataAnchor;

public:
explicit InControlAnchor(const NodePtr &ownerNode);

explicit InControlAnchor(const NodePtr &ownerNode, int idx);

virtual ~InControlAnchor() = default;

// Get source out control anchors
Vistor<OutControlAnchorPtr> GetPeerOutControlAnchors() const;
bool IsPeerOutAnchorsEmpty() const { return peer_anchors_.empty(); }

// Get source out data anchors
Vistor<OutDataAnchorPtr> GetPeerOutDataAnchors() const;

// Build connection from OutControlAnchor to InControlAnchor
graphStatus LinkFrom(const OutControlAnchorPtr &src);

protected:
bool Equal(AnchorPtr anchor) const override;
bool IsTypeOf(TYPE type) const override;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor {
friend class InControlAnchor;

public:
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<ConstAnchor>>;

explicit OutControlAnchor(const NodePtr &ownerNode);

explicit OutControlAnchor(const NodePtr &ownerNode, int idx);

virtual ~OutControlAnchor() = default;

// Get dst in control anchor(one or more)
Vistor<InControlAnchorPtr> GetPeerInControlAnchors() const;
// Get dst data anchor in control anchor(one or more)
Vistor<InDataAnchorPtr> GetPeerInDataAnchors() const;

// Build connection from OutControlAnchor to InControlAnchor
graphStatus LinkTo(const InControlAnchorPtr &dest);
// Build connection from OutDataAnchor to InDataAnchor
graphStatus LinkTo(const InDataAnchorPtr &dest);

protected:
bool Equal(AnchorPtr anchor) const override;
bool IsTypeOf(TYPE type) const override;
};
} // namespace ge
#endif // INC_GRAPH_ANCHOR_H_

+ 0
- 191
metadef/inc/graph/attr_value_serializable.h View File

@@ -1,191 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_
#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_

#include <string>
#include <vector>
#include "graph/ge_attr_value.h"

namespace ge {

class GeAttrValue;
class _GeSerializable {
public:
template <typename T>
struct ge_serializable_int64_t_support_type {
using DT = typename std::remove_cv<T>::type;
static const bool value = std::is_same<DT, uint64_t>::value // by cast
|| std::is_same<DT, int32_t>::value || std::is_same<DT, uint32_t>::value ||
std::is_same<DT, int16_t>::value || std::is_same<DT, uint16_t>::value ||
std::is_same<DT, int8_t>::value || std::is_same<DT, uint8_t>::value;
};

template <typename T, typename T::__ge_serializable = 0>
static GeAttrValue SaveItemAsAttrValue(const T &t) {
return GeAttrValue::CreateFrom(t);
}

template <typename T, typename T::__ge_serializable = 0>
static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) {
return GeAttrValue::CreateFrom(t);
}

template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type>
static GeAttrValue SaveItemAsAttrValue(const T &t) {
return GeAttrValue::CreateFrom<DT>(t);
}
// int64_t support type
template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
static GeAttrValue SaveItemAsAttrValue(const T &t) {
return GeAttrValue::CreateFrom<GeAttrValue::INT>(t);
}
// vector int64_t support type
template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) {
return GeAttrValue::CreateFrom<GeAttrValue::LIST_INT>(t);
}

template <typename T, typename T::__ge_serializable = 0>
static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
return attrVal.GetValue(t);
}

template <typename T, typename T::__ge_serializable = 0>
static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) {
return attrVal.GetValue(t);
}

template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type>
static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
return attrVal.GetValue<DT>(t);
}

template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
return attrVal.GetValue<GeAttrValue::INT>(t);
}

template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) {
return attrVal.GetValue<GeAttrValue::LIST_INT>(t);
}

template <class T, class... Args>
static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) {
GeAttrValue itemVal = SaveItemAsAttrValue(item);
(void)namedAttrs.SetAttr(itemName, itemVal);
SaveItem(namedAttrs, args...);
}

static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {}

template <class T, class... Args>
static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) {
auto itemVal = namedAttrs.GetItem(itemName);
auto status = LoadItemFromAttrValue(item, itemVal);
if (status != GRAPH_SUCCESS) {
return status;
}
return LoadItem(namedAttrs, args...);
}

static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {
return GRAPH_SUCCESS;
}
};

#define _GE_FI(a) #a, a
#define _GE_MAP_FIELDS1(a1) _GE_FI(a1)
#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2)
#define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3)
#define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4)
#define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5)
#define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6)
#define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7)
#define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8)
#define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9)
#define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10)
#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11)
#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12)
#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13)
#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14)
#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15)

#define _GE_PRIVATE_ARGS_GLUE(x, y) x y

#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \
...) \
N
#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args
#define _GE_COUNT_MACRO_VAR_ARGS(...) \
_GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))

#define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count
#define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count)
#define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count)

#define _GE_INVOKE_VAR_MACRO(...) \
_GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \
(__VA_ARGS__))

#define GE_SERIALIZABLE(...) \
public: \
friend class ge::GeAttrValue; \
using __ge_serializable = int; \
\
private: \
ge::graphStatus Save(GeAttrValue &ar) const { \
GeAttrValue::NAMED_ATTRS named_attrs; \
_GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \
return ar.SetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \
} \
ge::graphStatus Load(const GeAttrValue &ar) { \
GeAttrValue::NAMED_ATTRS named_attrs; \
ge::graphStatus status = ar.GetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \
if (status != GRAPH_SUCCESS) { \
return status; \
} \
return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \
}

// end NamedAttrs Helper: GE_SERIALIZABLE
} // namespace ge
#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_

+ 0
- 82
metadef/inc/graph/buffer.h View File

@@ -1,82 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_BUFFER_H_
#define INC_GRAPH_BUFFER_H_

#include <graph/types.h>
#include <memory>
#include <string>
#include <vector>
#include "detail/attributes_holder.h"

namespace ge {
#ifdef HOST_VISIBILITY
#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_HOST_VISIBILITY
#endif
#ifdef DEV_VISIBILITY
#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_DEV_VISIBILITY
#endif

using std::shared_ptr;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer {
public:
Buffer();
Buffer(const Buffer &other);

explicit Buffer(std::size_t bufferSize, std::uint8_t defualtVal = 0);

~Buffer() = default;

Buffer &operator=(const Buffer &other);

static Buffer CopyFrom(const std::uint8_t *data, std::size_t bufferSize);

const std::uint8_t *GetData() const;
std::uint8_t *GetData();
std::size_t GetSize() const;
void ClearBuffer();

// For compatibility
inline const std::uint8_t *data() const { return GetData(); }
inline std::uint8_t *data() { return GetData(); } // lint !e659
inline std::size_t size() const { return GetSize(); }
inline void clear() { return ClearBuffer(); }
uint8_t operator[](size_t index) const { // lint !e1022 !e1042
if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574
return (uint8_t)(*buffer_)[index];
}
return 0xff;
}

private:
GeIrProtoHelper<proto::AttrDef> data_;
std::string *buffer_ = nullptr;

// Create from protobuf obj
Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer);
Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer);

friend class GeAttrValueImp;
friend class GeTensor;
};
} // namespace ge
#endif // INC_GRAPH_BUFFER_H_

+ 0
- 308
metadef/inc/graph/compute_graph.h View File

@@ -1,308 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_COMPUTE_GRAPH_H_
#define INC_GRAPH_COMPUTE_GRAPH_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <deque>
#include "detail/attributes_holder.h"
#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/range_vistor.h"

namespace ge {
class Node;
using NodePtr = std::shared_ptr<Node>;
class Edge;
using EdgePtr = std::shared_ptr<Edge>;

class InDataAnchor;
using InDataAnchorPtr = std::shared_ptr<InDataAnchor>;

class OutDataAnchor;
using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>;

class ControlAnchor;
using ControlAnchorPtr = std::shared_ptr<ControlAnchor>;
class InControlAnchor;
using InControlAnchorPtr = std::shared_ptr<InControlAnchor>;
class OutControlAnchor;
using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>;
class GeAttrValue;
using AttrValuePtr = std::shared_ptr<GeAttrValue>;
using ConstComputeGraph = const ComputeGraph;

class OperatorImpl;
using OperatorImplPtr = std::shared_ptr<OperatorImpl>;

class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
friend class GraphUtils;

public:
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>;

explicit ComputeGraph(const std::string &name);
~ComputeGraph() override;

std::string GetName() const;
void SetName(const std::string &name);

using AttrHolder::DelAttr;
using AttrHolder::GetAttr;
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;

size_t GetAllNodesSize() const;
Vistor<NodePtr> GetAllNodes() const;
// is_unknown_shape: false, same with GetAllNodes func
// is_unknown_shape: true, same with GetDirectNodes func
Vistor<NodePtr> GetNodes(bool is_unknown_shape) const;
size_t GetDirectNodesSize() const;
Vistor<NodePtr> GetDirectNode() const;
Vistor<NodePtr> GetInputNodes() const;
Vistor<NodePtr> GetOutputNodes() const;

NodePtr FindNode(const std::string &name) const;
NodePtr FindFirstNodeMatchType(const std::string &name) const;
/*lint -e504*/
// AddNode with NodePtr
NodePtr AddNode(NodePtr node);
NodePtr AddNode(OpDescPtr op);
NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize
NodePtr AddNodeFront(NodePtr node);
NodePtr AddNodeFront(const OpDescPtr &op);
NodePtr AddInputNode(NodePtr node);
NodePtr AddOutputNode(NodePtr node);
NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index);
// insert node with specific pre_node
NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node);
NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node);

graphStatus RemoveNode(const NodePtr &node);
graphStatus RemoveInputNode(const NodePtr &node);
graphStatus RemoveOutputNode(const NodePtr &node);
graphStatus RemoveConstInput(const NodePtr &node);

/// Add a subgraph to this graph. The subgraph must has a parent graph and parent node,
/// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph
/// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph()
/// must equal to subgraph->GetOwnerGraph().
/// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph.
/// The subgraph's name SHOULD(not must) be the same as the parameter `name`
graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph);
graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);

void RemoveSubgraph(const std::string &name);
void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);

std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const;
std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const;

// obsolete
std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph);
// obsolete
graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph);

///
/// @brief Update input-mapping
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
/// @return graphStatus
///
graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);

///
/// @brief Update output-mapping
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
/// @return graphStatus
///
graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);

graphStatus TopologicalSorting();
bool IsValid() const;
void InValid() { is_valid_flag_ = false; }
void Dump() const;

void Swap(ComputeGraph &graph);

graphStatus IsolateNode(const NodePtr &node);
graphStatus Verify();
graphStatus InferShape();
graphStatus InferOriginFormat();
graphStatus InferShapeInNeed();
graphStatus InsertEventNodes();
bool operator==(const ComputeGraph &r_compute_graph) const;

/*lint +e504*/
const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const {
return params_share_map_;
}

void SetShareParamLayer(const std::map<std::vector<std::string>, std::vector<std::string>> params_share_map) {
params_share_map_ = params_share_map;
}

void SetInputsOrder(const std::vector<std::string> &inputs_order) { inputs_order_ = inputs_order; }

void SetGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) { out_nodes_map_ = out_nodes_map; }

void AppendGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) {
for (auto &item : out_nodes_map) {
(void)out_nodes_map_.emplace(item.first, item.second);
}
}

shared_ptr<ComputeGraph> GetParentGraph();
void SetParentGraph(const shared_ptr<ComputeGraph> &parent);
shared_ptr<Node> GetParentNode();
void SetParentNode(const shared_ptr<Node> &parent);

const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; }

void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; }

ComputeGraphPtr GetOrigGraph(void) { return origGraph_; }
void SetOutputSize(uint32_t size) { output_size_ = size; }
uint32_t GetOutputSize() const { return output_size_; }
void SetInputSize(uint32_t size) { input_size_ = size; }
uint32_t GetInputSize() const { return input_size_; }

// false: known shape true: unknow shape
bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; }
void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; }

///
/// Set is need train iteration.
/// If set true, it means this graph need to be run iteration some
/// times(according variant "npu_runconfig/iterations_per_loop").
/// @param need_iteration is need iteration
///
void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; }

void SetUserDefOutput(const std::string &output_name);

const std::string GetOutput();

///
/// Get is need train iteration.
/// @return is need iteration
///
bool GetNeedIteration() const { return need_iteration_; }

void SetGraphOpName(const std::map<uint32_t, std::string> &op_name_map) { op_name_map_ = op_name_map; }
const std::map<uint32_t, std::string> &GetGraphOpName() const { return op_name_map_; }

const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const;

void SetAllNodesInfo(const std::map<OperatorImplPtr, NodePtr> &nodes) { all_nodes_infos_ = nodes; }

void SetGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
output_nodes_info_ = out_nodes_info;
}

void AppendGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end());
}

const std::vector<std::pair<NodePtr, int32_t>> &GetGraphOutNodesInfo() const { return output_nodes_info_; }

void SetGraphTargetNodesInfo(const std::vector<NodePtr> &target_nodes_info) {
target_nodes_info_ = target_nodes_info;
}
const std::vector<NodePtr> &GetGraphTargetNodesInfo() const { return target_nodes_info_; }

void SetSessionID(uint64_t session_id) { session_id_ = session_id; }
uint64_t GetSessionID() const { return session_id_; }

void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; }
uint32_t GetGraphID() const { return graph_id_; }

void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; }
ge::Format GetDataFormat() const { return data_format_; }
bool IsSummaryGraph() const { return is_summary_graph_; }
void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; }
// Graph Before BFE
ComputeGraphPtr origGraph_;

protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;

private:
graphStatus DFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::vector<NodePtr> &stack);
graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::deque<NodePtr> &stack);
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::map<string, NodePtr> &breadth_node_map);
graphStatus TopologicalSortingGraph();
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum);
Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const;
size_t GetInEdgeSize(const NodePtr &node);
size_t GetOutEdgeSize(const NodePtr &node);
graphStatus RemoveExtraOutEdge(const NodePtr &node);
bool GraphMembersAreEqual(const ComputeGraph &r_graph) const;
bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const;
bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector,
const std::vector<NodePtr> &l_node_ptr_vector) const;

void SetNodesOwner();

friend class ModelSerializeImp;
friend class GraphDebugImp;
friend class OnnxUtils;
friend class TuningUtils;

std::string name_;
uint32_t graph_id_ = 0;
ProtoAttrMapHelper attrs_;
std::vector<NodePtr> nodes_;
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_;
std::vector<NodePtr> target_nodes_info_;

std::vector<NodePtr> input_nodes_;
std::vector<std::string> inputs_order_;
uint32_t input_size_ = 1;
std::map<std::string, std::vector<int32_t>> out_nodes_map_;
uint32_t output_size_ = 1;
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_;

std::vector<std::shared_ptr<ComputeGraph>> sub_graph_;
std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_;
std::weak_ptr<ComputeGraph> parent_graph_;
std::weak_ptr<Node> parent_node_;

// the members followed should not in the ComputeGraph class
bool is_valid_flag_;
bool is_summary_graph_ = false;
// Indicates whether it is need iteration
bool need_iteration_ = false;
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_;
// TaskIdx -> op_name Map
std::map<uint32_t, std::string> op_name_map_;
uint64_t session_id_ = 0;
ge::Format data_format_ = ge::FORMAT_ND;
// unknown graph indicator, default is false, mean known shape
bool is_unknown_shape_graph_ = false;
};
} // namespace ge
#endif // INC_GRAPH_COMPUTE_GRAPH_H_

+ 0
- 1122
metadef/inc/graph/debug/ge_attr_define.h
File diff suppressed because it is too large
View File


+ 0
- 195
metadef/inc/graph/def_types.h View File

@@ -1,195 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_DEF_TYPES_H_
#define INC_GRAPH_DEF_TYPES_H_

#include <atomic>
#include <memory>
#include <vector>
#include "graph/attr_value_serializable.h"
#include "graph/buffer.h"
namespace ge {
#define DEF_TYPE_DEC(type, name) \
inline void set_##name(const type &value) { name = value; } \
type *mutable_##name() { return &name; }

#define DEF_TYPE_HAS_DEC(type, name) \
inline void set_##name(const type &value) { name = value; } \
\
private: \
bool has_mutable_##name{false}; \
\
public: \
bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \
type *mutable_##name() { \
has_mutable_##name = true; \
return &name; \
}

#define DEF_TYPE_VEC_DEC(type, name) \
inline int name##_size() const { return name.size(); } \
inline void clear_##name() { name.clear(); } \
inline void set_##name(int index, type value) { name[index] = value; } \
inline void add_##name(type value) { name.push_back(value); } \
inline std::vector<type> *mutable_##name() { return &name; }

#define DEF_TYPE_BYTES_DEC(name) \
inline void clear_##name() { name.ClearBuffer(); } \
inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \
inline Buffer *mutable_##name() { return &name; }

struct CompressInfo {
public:
CompressInfo() {}
CompressInfo(int32_t blockRow, int32_t blockCol, int32_t fractalK, int32_t fractalN, int32_t lastFractalK,
int32_t lastFractalN, int32_t cubeSize, int32_t loadDir) {
blockrow = blockRow;
blockcol = blockCol;
fractalk = fractalK;
fractaln = fractalN;
lastfractalk = lastFractalK;
lastfractaln = lastFractalN;
cubesize = cubeSize;
loaddir = loadDir;
}

int32_t blockrow{0}; // Block row
int32_t blockcol{0}; // Block col
int32_t fractalk{0}; // Fractal K
int32_t fractaln{0}; // Fractal N
int32_t lastfractalk{0}; // K of last fractal
int32_t lastfractaln{0}; // N of last fractal
int32_t cubesize{0}; // Cube's length
int32_t loaddir{0}; // Data load directtiono 0:col load 1:row load
DEF_TYPE_DEC(int32_t, blockrow);
DEF_TYPE_DEC(int32_t, blockcol);
DEF_TYPE_DEC(int32_t, fractalk);
DEF_TYPE_DEC(int32_t, fractaln);
DEF_TYPE_DEC(int32_t, lastfractalk);
DEF_TYPE_DEC(int32_t, lastfractaln);
DEF_TYPE_DEC(int32_t, cubesize);
DEF_TYPE_DEC(int32_t, loaddir);

GE_SERIALIZABLE(blockrow, blockcol, fractalk, fractaln, lastfractalk, lastfractaln, cubesize, loaddir);
};

enum QuantizeScaleType { VECTOR_SCALE = 0, SCALAR_SCALE = 1 };
enum QuantizeScaleMode { NORMAL_MODE = 0, SQRT_MODE = 1 };
enum QuantizeAlgorithm {
NON_OFFSET_ALGO = 0,
HALF_OFFSET_ALGO = 1,
ALL_OFFSET_ALGO = 2,
};
struct QuantizeFactor {
public:
// QuantizeScaleMode scale_mode;
uint32_t scale_mode{0};
Buffer scale_value;
int64_t scale_offset{0};
Buffer offset_data_value;
int64_t offset_data_offset{0};
Buffer offset_weight_value;
int64_t offset_weight_offset{0};
Buffer offset_pad_value;
int64_t offset_pad_offset{0};

DEF_TYPE_DEC(uint32_t, scale_mode);
DEF_TYPE_BYTES_DEC(scale_value);

DEF_TYPE_DEC(int64_t, scale_offset);
DEF_TYPE_BYTES_DEC(offset_data_value);
DEF_TYPE_DEC(int64_t, offset_data_offset);

DEF_TYPE_BYTES_DEC(offset_weight_value);
DEF_TYPE_DEC(int64_t, offset_weight_offset);
DEF_TYPE_BYTES_DEC(offset_pad_value);
DEF_TYPE_DEC(int64_t, offset_pad_offset);

GE_SERIALIZABLE(scale_mode, scale_value, scale_offset, offset_data_value, offset_data_offset, offset_weight_value,
offset_weight_offset, offset_pad_value, offset_pad_offset)
};

static inline bool QuantizeFactorHasData(const QuantizeFactor &factor) {
return factor.scale_value.GetSize() > 0 || factor.offset_data_value.GetSize() > 0 ||
factor.offset_weight_value.GetSize() > 0 || factor.offset_pad_value.GetSize() > 0;
}

struct AllOffsetQuantizeInfo {
public:
AllOffsetQuantizeInfo() {}
AllOffsetQuantizeInfo(float s, int32_t o) : scale(s), offset(o) {}
float scale{0};
int32_t offset{0};

DEF_TYPE_DEC(float, scale);
DEF_TYPE_DEC(int32_t, offset);

GE_SERIALIZABLE(scale, offset)
};

struct QuantizeCalcFactor {
public:
Buffer offsetw;
int64_t offsetw_offset{0};
Buffer offsetd;
int64_t offsetd_offset{0};
Buffer scalereq;
int64_t scaledreq_offset{0};
Buffer offsetdnext;
int64_t offsetdnext_offset{0};

DEF_TYPE_BYTES_DEC(offsetw);
DEF_TYPE_DEC(int64_t, offsetw_offset);
DEF_TYPE_BYTES_DEC(offsetd);
DEF_TYPE_DEC(int64_t, offsetd_offset);
DEF_TYPE_BYTES_DEC(scalereq);
DEF_TYPE_DEC(int64_t, scaledreq_offset);
DEF_TYPE_BYTES_DEC(offsetdnext);
DEF_TYPE_DEC(int64_t, offsetdnext_offset);

GE_SERIALIZABLE(offsetw, offsetw_offset, offsetd, offsetd_offset, scalereq, scaledreq_offset, offsetdnext,
offsetdnext_offset);
};

static inline bool QuantizeFactorHasData(const QuantizeCalcFactor &factor) {
return factor.offsetw.GetSize() > 0 || factor.offsetd.GetSize() > 0 || factor.scalereq.GetSize() > 0 ||
factor.offsetdnext.GetSize() > 0;
}

struct QuantizeFactorParams {
uint32_t quantize_algo{0};
uint32_t scale_type{0};
QuantizeFactor quantize_param;
QuantizeFactor dequantize_param;
QuantizeFactor requantize_param;
QuantizeCalcFactor quantizecalc_param;
DEF_TYPE_DEC(uint32_t, quantize_algo);
DEF_TYPE_DEC(uint32_t, scale_type);
DEF_TYPE_HAS_DEC(QuantizeFactor, quantize_param);
DEF_TYPE_HAS_DEC(QuantizeFactor, dequantize_param);
DEF_TYPE_HAS_DEC(QuantizeFactor, requantize_param);
DEF_TYPE_HAS_DEC(QuantizeCalcFactor, quantizecalc_param);

GE_SERIALIZABLE(quantize_algo, scale_type, quantize_param, dequantize_param, requantize_param, quantizecalc_param,
has_mutable_quantize_param, has_mutable_dequantize_param, has_mutable_requantize_param,
has_mutable_quantizecalc_param);
};

#undef DEF_TYPE_DEC
} // namespace ge

#endif // INC_GRAPH_DEF_TYPES_H_

+ 0
- 120
metadef/inc/graph/detail/any_map.h View File

@@ -1,120 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_DETAIL_ANY_MAP_H_
#define INC_GRAPH_DETAIL_ANY_MAP_H_

#include <map>
#include <memory>
#include <string>
#include <utility>

namespace ge {
using std::shared_ptr;
using std::string;

class TypeID {
public:
template <class T>
static TypeID Of() {
return TypeID(__PRETTY_FUNCTION__);
}

~TypeID() = default;

bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; }

private:
explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32

string type_;
};

class AnyMap {
public:
template <class DT>
bool Set(const string &name, const DT &val);

template <class T>
bool Get(const string &name, T &retValue) const;

bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); }

void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); }

private:
class Placeholder {
public:
virtual ~Placeholder() = default;

virtual const TypeID &GetTypeInfo() const = 0;
};

template <typename VT>
class Holder : public Placeholder {
public:
explicit Holder(const VT &value) : value_(value) {}

~Holder() override = default;

const TypeID &GetTypeInfo() const override {
static const TypeID typeId = TypeID::Of<VT>();
return typeId;
}

const VT value_;
};

std::map<string, shared_ptr<Placeholder>> anyValues_;
};

template <class DT>
bool AnyMap::Set(const string &name, const DT &val) {
auto it = anyValues_.find(name);

std::shared_ptr<Holder<DT>> tmp;
try {
tmp = std::make_shared<Holder<DT>>(val);
} catch (std::bad_alloc &e) {
tmp = nullptr;
} catch (...) {
tmp = nullptr;
}

if (it == anyValues_.end()) {
(void)anyValues_.emplace(name, tmp);
} else {
if (it->second && it->second->GetTypeInfo() == TypeID::Of<DT>()) {
it->second = tmp;
} else {
return false;
}
}
return true;
}

template <class T>
bool AnyMap::Get(const string &name, T &retValue) const {
auto it = anyValues_.find(name);
if (it != anyValues_.end() && it->second && it->second->GetTypeInfo() == TypeID::Of<T>()) {
auto retPtr = std::static_pointer_cast<Holder<T>>(it->second);
retValue = retPtr->value_;
return true;
}
return false;
}
} // namespace ge
#endif // INC_GRAPH_DETAIL_ANY_MAP_H_

+ 0
- 165
metadef/inc/graph/detail/attributes_holder.h View File

@@ -1,165 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_
#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_

#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "graph/detail/any_map.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"

namespace google {
namespace protobuf {
class Message;
template <typename Key, typename T>
class Map;
} // namespace protobuf
} // namespace google

namespace ge {
using std::string;
class GeAttrValue;

namespace proto {
class AttrDef;
class TensorDef;
class TensorDescriptor;
class ShapeDef;
class NamedAttrs;
class ModelDef;
class OpDef;
class GraphDef;
} // namespace proto

using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073
using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>;

template <class ProtoType>
class GeIrProtoHelper {
public:
GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *protoMsg)
: protoOwner_(protoOwner), protoMsg_(protoMsg) {}

GeIrProtoHelper() {
protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr);
protoMsg_ = nullptr;
}
virtual ~GeIrProtoHelper() = default;

template <typename T>
GeIrProtoHelper(const GeIrProtoHelper<T> &other) {
protoOwner_ = other.protoOwner_;
protoMsg_ = other.protoMsg_;
}
template <typename T>
GeIrProtoHelper &operator=(const GeIrProtoHelper<T> &other) {
protoOwner_ = other.protoOnwer_;
protoMsg_ = other.protoMsg_;
return *this;
}
void InitDefault();
template <typename T>
bool operator==(const GeIrProtoHelper<T> &other) const {
return protoOwner_ == other.protoOwner_ && protoMsg_ == other.protoMsg_;
}

inline const ProtoMsgOwner &GetProtoOwner() const { return protoOwner_; }
inline ProtoType *GetProtoMsg() const { return protoMsg_; }
void CopyValueFrom(const GeIrProtoHelper<const ProtoType> &other) {
if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) {
*protoMsg_ = *other.protoMsg_;
}
}
void MoveValueFrom(GeIrProtoHelper<ProtoType> &&other) {
if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) {
*protoMsg_ = std::move(*other.protoMsg_);
}
}

void Swap(GeIrProtoHelper<ProtoType> &other) {
protoOwner_.swap(other.protoOwner_);

ProtoType *temp = protoMsg_;
protoMsg_ = other.protoMsg_;
other.protoMsg_ = temp;
}

// protoMsg_ is part of protoOwner_, they have the same runtime
ProtoMsgOwner protoOwner_ = nullptr;
ProtoType *protoMsg_ = nullptr;
friend class GeIrProtoHelper<typename std::conditional<
std::is_const<ProtoType>::value, typename std::remove_const<ProtoType>::type, const ProtoType>::type>;
};

using ProtoAttrMapHelper = GeIrProtoHelper<ProtoAttrMap>;
using ConstProtoAttrMapHelper = GeIrProtoHelper<const ProtoAttrMap>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
public:
AttrHolder() = default;
virtual ~AttrHolder() = default;

graphStatus SetAttr(const string &name, const GeAttrValue &value);

graphStatus GetAttr(const string &name, GeAttrValue &value) const;

bool HasAttr(const string &name) const;

graphStatus DelAttr(const string &name);

void CopyAttrsFrom(const AttrHolder &holder);

void Swap(AttrHolder &holder) {
requiredAttrs_.swap(holder.requiredAttrs_);
extAttrs_.Swap(holder.extAttrs_);
}

template <class T>
bool SetExtAttr(const string &name, const T &value) {
return extAttrs_.Set(name, value);
}
template <class T>
T TryGetExtAttr(const string &name, T defaultValue) const {
T ret(defaultValue);
(void)extAttrs_.Get(name, ret);
return ret;
}

protected:
graphStatus AddRequiredAttr(const std::string &name);
const std::unordered_set<string> GetAllAttrNames() const;
const std::map<string, GeAttrValue> GetAllAttrs() const; // lint !e1073

virtual ProtoAttrMapHelper MutableAttrMap() = 0;
virtual ConstProtoAttrMapHelper GetAttrMap() const = 0;

friend class ModelSerializeImp;
friend class AttrUtils;
friend class AttrUtilsHelper;

std::vector<string> requiredAttrs_;

private:
AnyMap extAttrs_;
};
} // namespace ge
#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_

+ 0
- 93
metadef/inc/graph/detail/model_serialize_imp.h View File

@@ -1,93 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_
#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "graph/anchor.h"
#include "graph/detail/attributes_holder.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;

struct NodeNameGraphReq {
string node_name;
int32_t index;
ComputeGraphPtr graph;
};

struct NodeNameNodeReq {
string src_node_name;
int32_t src_out_index;
NodePtr dst_node;
int32_t dst_in_index;
string dst_node_name;
};

class ModelSerializeImp {
public:
bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false);

bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false);

bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto);

bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false);

bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false);

bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto);

bool UnserializeModel(Model &model, proto::ModelDef &modeProto);

bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graphProto);

bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graphProto);

bool HandleNodeNameRef();

bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto);
void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out,
std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out, std::vector<string> &opt);
void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto);

bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto);

bool UnserializeTensor(GeTensorPtr &tensor, proto::TensorDef &tensorProto);

bool ParseNodeIndex(const string &node_index, string &nodeName, int32_t &index);

void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; }

private:
bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map<std::string, ComputeGraphPtr> &subgraphs);

std::vector<NodeNameGraphReq> graph_input_node_names_;
std::vector<NodeNameGraphReq> graph_output_node_names_;
std::vector<NodeNameNodeReq> node_input_node_names_;
std::map<string, NodePtr> node_map_;
ProtoMsgOwner protobuf_owner_;
};
} // namespace ge

#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_

+ 0
- 343
metadef/inc/graph/ge_attr_value.h View File

@@ -1,343 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_GE_ATTR_VALUE_H_
#define INC_GRAPH_GE_ATTR_VALUE_H_

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "graph/buffer.h"
#include "detail/attributes_holder.h"
#include "graph/ge_error_codes.h"
#include "graph/ge_tensor.h"

using std::map;
using std::string;
using std::vector;

namespace ge {
class GeTensor;

using GeTensorPtr = std::shared_ptr<GeTensor>;
using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;

class ComputeGraph;
using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;
using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>;

class GeTensorDesc;
class GeAttrValue;
class GeAttrValueImp;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder {
public:
NamedAttrs();
virtual ~NamedAttrs() = default;
void SetName(const std::string &name);
string GetName() const;
GeAttrValue GetItem(const string &key) const;

protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;

private:
// Create namedAttrs from protobuf obj
NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg);
GeIrProtoHelper<proto::NamedAttrs> named_attrs_;
friend class GeAttrValueImp;
friend class GeAttrValue;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
public:
using INT = int64_t;
using FLOAT = float;
using BOOL = bool;
using STR = std::string;
using TENSOR = GeTensorPtr;
using TENSOR_DESC = GeTensorDesc;
using GRAPH = ComputeGraphPtr;
using BYTES = Buffer;
using NAMED_ATTRS = ge::NamedAttrs;
using DATA_TYPE = ge::DataType;

using LIST_INT = vector<INT>;
using LIST_FLOAT = vector<FLOAT>;
using LIST_BOOL = vector<BOOL>;
using LIST_STR = vector<STR>;
using LIST_TENSOR = vector<TENSOR>;
using LIST_TENSOR_DESC = vector<TENSOR_DESC>;
using LIST_GRAPH = vector<GRAPH>;
using LIST_BYTES = vector<BYTES>;
using LIST_NAMED_ATTRS = vector<NAMED_ATTRS>;
using LIST_LIST_INT = vector<vector<int64_t>>;
using LIST_DATA_TYPE = vector<ge::DataType>;

using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs).

enum ValueType {
VT_NONE = 0,
VT_STRING,
VT_FLOAT,
VT_BOOL,
VT_INT,
VT_TENSOR_DESC,
VT_TENSOR,
VT_BYTES,
VT_GRAPH,
VT_NAMED_ATTRS,
VT_LIST_LIST_INT,
VT_DATA_TYPE,

VT_LIST_BASE = 1000,
VT_LIST_STRING = VT_LIST_BASE + VT_STRING,
VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT,
VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL,
VT_LIST_INT = VT_LIST_BASE + VT_INT,
VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC,
VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR,
VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES,
VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH,
VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS,
VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE,
};

template <class T>
struct IsAttrTypeEnable {
using DT = typename std::remove_cv<T>::type;

static bool const VALUE = std::is_same<INT, DT>::value || std::is_same<FLOAT, DT>::value ||
std::is_same<BOOL, DT>::value || std::is_same<STR, DT>::value ||
std::is_same<GRAPH, DT>::value || std::is_same<TENSOR, DT>::value ||
std::is_same<TENSOR_DESC, DT>::value || std::is_same<BYTES, DT>::value ||
std::is_same<NAMED_ATTRS, DT>::value || std::is_same<DATA_TYPE, DT>::value;

// Not has list type of NamedAttrs
static bool const LIST_VALUE = std::is_same<LIST_INT, DT>::value || std::is_same<LIST_FLOAT, DT>::value ||
std::is_same<LIST_BOOL, DT>::value || std::is_same<LIST_STR, DT>::value ||
std::is_same<LIST_GRAPH, DT>::value || std::is_same<LIST_TENSOR, DT>::value ||
std::is_same<LIST_TENSOR_DESC, DT>::value || std::is_same<LIST_BYTES, DT>::value ||
std::is_same<LIST_NAMED_ATTRS, DT>::value ||
std::is_same<LIST_LIST_INT, DT>::value || std::is_same<LIST_DATA_TYPE, DT>::value;
};

template <typename vector_type>
// To cols
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type;

template <typename one_type>
using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type;

template <typename val_type>
using enable_if_type_valid_t =
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type;

template <typename seriliable_type>
using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable;

GeAttrValue();
~GeAttrValue() = default;
// SetValue, Set initializer_list
template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
graphStatus SetValue(std::initializer_list<DT> &&val) {
T vectorVal;
for (auto &item : val) {
vectorVal.push_back(item);
}
return SetValue(vectorVal);
}

// SetValue, Set vector
template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
graphStatus SetValue(const std::vector<DT> &val) {
T vectorVal;
for (auto item : val) {
vectorVal.push_back(item);
}
return SetValue(vectorVal);
}

// SetValue, not list type
template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0>
graphStatus SetValue(DT &&val) {
return SetValue(T(std::forward<DT>(val)));
}

// GE_SERIALIZABLE
template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
graphStatus SetValue(const T &t) {
return t.Save(*this);
}

template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
graphStatus SetValue(const vector<T> &t) {
vector<NamedAttrs> attrs;
for (auto &item : t) {
GeAttrValue val;
item.Save(val);
NamedAttrs attrsItem;
(void)val.GetValue<NamedAttrs>(attrsItem);
attrs.push_back(attrsItem);
}
return SetValue(attrs);
}

// GetValue, list value
template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0,
typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
graphStatus GetValue(std::vector<DT> &val) const {
T valGet;
val.clear();
auto status = GetValue(valGet);
if (status != GRAPH_SUCCESS) {
return status;
}
for (auto item : valGet) {
val.push_back(item);
}
return GRAPH_SUCCESS;
}

// GetValue, not list type
template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0,
typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
graphStatus GetValue(DT &val) const {
T valGet;
auto status = GetValue(valGet);
if (status != GRAPH_SUCCESS) {
return status;
}
val = DT(valGet);
return GRAPH_SUCCESS;
}

// GE_SERIALIZABLE
template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
graphStatus GetValue(T &t) {
return t.Load(*this);
}

template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
graphStatus GetValue(vector<T> &t) {
graphStatus status;
t.clear();
vector<NamedAttrs> attrs;
status = this->GetValue(attrs);
if (status != GRAPH_SUCCESS) {
return status;
}
for (auto &attr : attrs) {
T item;
GeAttrValue val;
(void)val.SetValue(attr);
status = item.Load(val);
if (status != GRAPH_SUCCESS) {
return status;
}
t.push_back(item);
}
return GRAPH_SUCCESS;
}

template <typename T, typename DT, enable_if_type_valid_t<T> = 0>
static GeAttrValue CreateFrom(DT &&val) {
GeAttrValue valRet;
(void)valRet.SetValue<T>(std::forward<DT>(val));
return valRet;
}

template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
static GeAttrValue CreateFrom(std::initializer_list<DT> &&val) {
GeAttrValue valRet;
(void)valRet.SetValue<T>(std::move(val));
return valRet;
}

template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
static GeAttrValue CreateFrom(const T &val) {
GeAttrValue valRet;
(void)valRet.SetValue(val);
return valRet;
}

template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
static GeAttrValue CreateFrom(const vector<T> &val) {
GeAttrValue valRet;
(void)valRet.SetValue(val);
return valRet;
}

ValueType GetValueType() const;

bool IsEmpty() const;

GeAttrValue Copy() const;

// For map key
bool operator==(const GeAttrValue &other) const { return value_ == other.value_; }

graphStatus MutableTensor(GeTensorPtr &tensor);
graphStatus MutableListTensor(vector<GeTensorPtr> &list_tensor);

private:
#define VALUE_SET_GET_DEC(DT) \
graphStatus SetValue(const DT &val); \
graphStatus GetValue(DT &val) const;
VALUE_SET_GET_DEC(GeAttrValue::STR)
VALUE_SET_GET_DEC(GeAttrValue::INT)
VALUE_SET_GET_DEC(GeAttrValue::FLOAT)
VALUE_SET_GET_DEC(GeAttrValue::BOOL)
VALUE_SET_GET_DEC(GeTensorDesc)
VALUE_SET_GET_DEC(GeAttrValue::TENSOR)
VALUE_SET_GET_DEC(GeAttrValue::GRAPH)
VALUE_SET_GET_DEC(BYTES)
VALUE_SET_GET_DEC(NamedAttrs)
VALUE_SET_GET_DEC(ge::DataType) // lint !e665
VALUE_SET_GET_DEC(vector<GeAttrValue::STR>)
VALUE_SET_GET_DEC(vector<GeAttrValue::INT>)
VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>)
VALUE_SET_GET_DEC(vector<GeAttrValue::BOOL>)
VALUE_SET_GET_DEC(vector<GeTensorDesc>)
VALUE_SET_GET_DEC(vector<GeAttrValue::TENSOR>)
VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>)
VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>)
VALUE_SET_GET_DEC(vector<NamedAttrs>)
VALUE_SET_GET_DEC(vector<vector<int64_t>>) // lint !e665
VALUE_SET_GET_DEC(vector<ge::DataType>) // lint !e665
#undef VALUE_SET_GET_DEC

GeIrProtoHelper<proto::AttrDef> value_;
GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val);

friend class AttrHolder;
friend class ModelSerializeImp;
friend class OnnxUtils;
};

class AttrValueImpl {
public:
AttrValueImpl() = default;
~AttrValueImpl() = default;

GeAttrValue geAttrValue_;
};
} // namespace ge
#endif // INC_GRAPH_GE_ATTR_VALUE_H_

+ 0
- 46
metadef/inc/graph/ge_context.h View File

@@ -1,46 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_GE_CONTEXT_H_
#define INC_GRAPH_GE_CONTEXT_H_

#include <string>
#include "graph/ge_error_codes.h"

namespace ge {
class GEContext {
public:
graphStatus GetOption(const std::string &key, std::string &option);
bool GetHostExecFlag();
uint64_t SessionId();
uint32_t DeviceId();
uint64_t TraceId();
void Init();
void SetSessionId(uint64_t session_id);
void SetCtxDeviceId(uint32_t device_id);

private:
uint64_t session_id_ = 0;
uint32_t device_id_ = 0;
uint64_t trace_id_ = 0;
}; // class GEContext

/// Get context
/// @return
GEContext &GetContext();
} // namespace ge

#endif // INC_GRAPH_GE_CONTEXT_H_

+ 0
- 26
metadef/inc/graph/ge_global_options.h View File

@@ -1,26 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_
#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_

#include <map>
#include <string>

namespace ge {
std::map<std::string, std::string> &GetMutableGlobalOptions();
}
#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_

+ 0
- 44
metadef/inc/graph/ge_local_context.h View File

@@ -1,44 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_
#define INC_GRAPH_GE_LOCAL_CONTEXT_H_

#include <map>
#include <string>
#include <vector>
#include "graph/ge_error_codes.h"

using std::map;
using std::string;

namespace ge {
class GEThreadLocalContext {
public:
graphStatus GetOption(const string &key, string &option);
void SetGraphOption(map<std::string, string> options_map);
void SetSessionOption(map<std::string, string> options_map);
void SetGlobalOption(map<std::string, string> options_map);

private:
map<string, string> graph_options_;
map<string, string> session_options_;
map<string, string> global_options_;
}; // class GEThreadLocalContext

GEThreadLocalContext &GetThreadLocalContext();
} // namespace ge
#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_

+ 0
- 193
metadef/inc/graph/ge_tensor.h View File

@@ -1,193 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_GE_TENSOR_H_
#define INC_GRAPH_GE_TENSOR_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>
#include "detail/attributes_holder.h"
#include "graph/buffer.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"

namespace ge {
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
public:
GeShape();
~GeShape() = default;
explicit GeShape(std::vector<int64_t> s);

size_t GetDimNum() const;
// If the idx is invalid, return 0
int64_t GetDim(size_t idx) const;
graphStatus SetDim(size_t idx, int64_t value);
std::vector<int64_t> GetDims() const;

int64_t GetShapeSize() const;
std::string ToString() const;

///
/// @brief Check is unknown shape
/// @return bool
///
bool IsUnknownShape() const;

///
/// @brief Check is a scalar
/// @return bool
///
bool IsScalar() const;

GeShape(const GeShape &other);
GeShape(GeShape &&other);
GeShape &operator=(const GeShape &other);
GeShape &operator=(GeShape &&other);

private:
GeIrProtoHelper<proto::ShapeDef> shape_def_;
friend class GeTensorDesc;
// Create from proto obj
GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg);

void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; }
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder {
friend class TensorUtils;
friend class GeAttrValue;
friend class ModelSerialize;

public:
GeTensorDesc();
explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
GeTensorDesc(const GeTensorDesc &desc);
GeTensorDesc(GeTensorDesc &&desc);

~GeTensorDesc() = default;
bool operator==(const GeTensorDesc &r_ge_tensor_desc) const;

void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);

GeShape GetShape() const;
GeShape &MutableShape();
void SetShape(GeShape shape);

// set shape with -2, it stand for unknown shape
void SetUnknownDimNumShape();
// for unknown shape
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;

GeShape GetOriginShape() const;
void SetOriginShape(const GeShape &originShape);

Format GetFormat() const;
void SetFormat(Format format);

Format GetOriginFormat() const;
void SetOriginFormat(Format originFormat);

void SetName(const std::string &name);
const std::string GetName() const;

DataType GetDataType() const;
void SetDataType(DataType dt);

DataType GetOriginDataType() const;
void SetOriginDataType(DataType originDataType);

std::vector<uint32_t> GetRefPortIndex() const;
void SetRefPortByIndex(const std::vector<uint32_t> &index);

GeTensorDesc Clone() const;
GeTensorDesc &operator=(const GeTensorDesc &desc);
GeTensorDesc &operator=(GeTensorDesc &&desc);

graphStatus IsValid() const;

protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;

private:
bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const;
using AttrHolder::DelAttr;
using AttrHolder::GetAllAttrs;
using AttrHolder::GetAttr;
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;

void Init();

// Create from proto obj
GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg);
friend class GeTensor;
friend class GeAttrValueImp;
friend class ModelSerializeImp;
friend class OnnxUtils;

GeIrProtoHelper<proto::TensorDescriptor> tensor_descriptor_;
// Reference from tensorDescriptor_, do not direct use
mutable GeShape __shape_;

void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; }
GeShape &ShapeReference() const;
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
public:
GeTensor();
explicit GeTensor(const GeTensorDesc &tensorDesc);
explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector<uint8_t> &data);
explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data);
explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size);
explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector<uint8_t> &&data);
~GeTensor() = default;

GeTensorDesc GetTensorDesc() const;
GeTensorDesc &MutableTensorDesc();
void SetTensorDesc(const GeTensorDesc &tensorDesc);

const Buffer GetData() const;
Buffer MutableData();
graphStatus SetData(std::vector<uint8_t> &&data);
graphStatus SetData(const std::vector<uint8_t> &data);
graphStatus SetData(const Buffer &data);
graphStatus SetData(const uint8_t *data, size_t size);

GeTensor Clone() const;

// Share value
GeTensor(const GeTensor &other);
// Share value
GeTensor &operator=(const GeTensor &other);

private:
friend class GeAttrValueImp;
friend class ModelSerializeImp;
friend class OnnxUtils;
// Create from proto obj
GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg);
GeIrProtoHelper<proto::TensorDef> tensor_def_;
// Reference from tensorDef_, do not direct use
mutable GeTensorDesc __desc_;
GeTensorDesc &DescReference() const;
};
} // namespace ge
#endif // INC_GRAPH_GE_TENSOR_H_

+ 0
- 134
metadef/inc/graph/graph_util.h View File

@@ -1,134 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_GRAPH_UTIL_H_
#define INC_GRAPH_GRAPH_UTIL_H_

#include <string>

#include "proto/om.pb.h"

namespace ge {
using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>;
bool HasOpAttr(const OpDef *opdef, std::string attr_name);
bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef);

static const char OP_TYPE_DATA[] = "Data";
static const char OP_TYPE_INPUT[] = "Input";
static const char ATTR_KEY_INPUT_FORMAT[] = "input_format";
static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format";
static const char OP_TYPE_ANN_DATA[] = "AnnData";
} // namespace ge

#if !defined(__ANDROID__) && !defined(ANDROID)
#include "toolchain/slog.h"
const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"};
#else
#include <syslog.h>
#include <utils/Log.h>
const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"};
#endif

#ifdef _MSC_VER
#define FUNC_NAME __FUNCTION__
#else
#define FUNC_NAME __PRETTY_FUNCTION__
#endif

#if !defined(__ANDROID__) && !defined(ANDROID)
#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \
dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__)
#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \
dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__)
#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \
dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__)
#else
#define D_GRAPH_LOG(level, format, ...) \
do { \
{ \
fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \
__FILE__, __LINE__, ##__VA_ARGS__); \
syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \
##__VA_ARGS__); \
} \
} while (0)
#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__)
#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__)
#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__)
#endif

#if !defined(__ANDROID__) && !defined(ANDROID)
#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__)
#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__)
#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__)
#else

#define GRAPH_LOG(level, format, ...) \
do { \
{ \
fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \
__FILE__, __LINE__, ##__VA_ARGS__); \
syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \
##__VA_ARGS__); \
} \
} while (0)
#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__)
#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__)
#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__)
#endif

#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \
do { \
const domi::graphStatus _status = (expr); \
if (_status != domi::GRAPH_SUCCESS) { \
return _status; \
} \
} while (0)

#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
GRAPH_LOGE(__VA_ARGS__); \
return _status; \
} \
} while (0)

#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \
{ \
bool b = (expr); \
if (!b) { \
exec_expr; \
} \
};

#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \
{ \
if (expr) { \
exec_expr; \
} \
}

#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) \
do { \
const ::domi::graphStatus _status = (expr); \
if (_status) { \
GRAPH_LOGE(__VA_ARGS__); \
return _status; \
} \
} while (0)

#endif // INC_GRAPH_GRAPH_UTIL_H_

+ 0
- 94
metadef/inc/graph/model.h View File

@@ -1,94 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_MODEL_H_
#define INC_GRAPH_MODEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "detail/attributes_holder.h"
#include "graph/ge_attr_value.h"
#include "graph/graph.h"

namespace ge {
using std::map;
using std::string;
using std::vector;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder {
public:
Model();

~Model() = default;

Model(const string &name, const string &custom_version);

string GetName() const;
void SetName(const string &name);

uint32_t GetVersion() const;

void SetVersion(uint32_t version) { version_ = version; }

std::string GetPlatformVersion() const;

void SetPlatformVersion(string version) { platform_version_ = version; }

Graph GetGraph() const;

void SetGraph(const Graph &graph);

void SetAttr(const ProtoAttrMapHelper &attrs);

using AttrHolder::GetAllAttrNames;
using AttrHolder::GetAllAttrs;
using AttrHolder::GetAttr;
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;

graphStatus Save(Buffer &buffer, bool is_dump = false) const;

graphStatus SaveToFile(const string &file_name) const;
// Model will be rewrite
static graphStatus Load(const uint8_t *data, size_t len, Model &model);
graphStatus Load(ge::proto::ModelDef &model_def);
graphStatus LoadFromFile(const string &file_name);

bool IsValid() const;

protected:
ConstProtoAttrMapHelper GetAttrMap() const override;
ProtoAttrMapHelper MutableAttrMap() override;

private:
void Init();
ProtoAttrMapHelper attrs_;
friend class ModelSerializeImp;
friend class GraphDebugImp;
friend class OnnxUtils;
friend class ModelHelper;
friend class ModelBuilder;
string name_;
uint32_t version_;
std::string platform_version_{""};
Graph graph_;
};
} // namespace ge
using ModelPtr = std::shared_ptr<ge::Model>;

#endif // INC_GRAPH_MODEL_H_

+ 0
- 52
metadef/inc/graph/model_serialize.h View File

@@ -1,52 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_MODEL_SERIALIZE_H_
#define INC_GRAPH_MODEL_SERIALIZE_H_

#include <map>
#include <string>
#include "graph/buffer.h"
#include "graph/compute_graph.h"
#include "graph/model.h"

namespace ge {
class ModelSerialize {
public:
Buffer SerializeModel(const Model &model, bool is_dump = false);

Model UnserializeModel(const uint8_t *data, size_t len);
Model UnserializeModel(ge::proto::ModelDef &model_def);

Buffer SerializeGraph(const ComputeGraphPtr &graph);

ComputeGraphPtr UnserializeGraph(const uint8_t *data, size_t len);

Buffer SerializeOpDesc(const ConstOpDescPtr &opDesc);
OpDescPtr UnserializeOpDesc(const uint8_t *data, size_t len);

size_t GetSerializeModelSize(const Model &model);

private:
static std::map<std::string, GeAttrValue> &MutableTensorDescAttrMap(GeTensorDesc &tensorDesc);

static const std::map<std::string, GeAttrValue> &GetTensorDescAttrMap(const GeTensorDesc &tensorDesc);

friend class ModelSerializeImp;
friend class GraphDebugImp;
};
} // namespace ge
#endif // INC_GRAPH_MODEL_SERIALIZE_H_

+ 0
- 213
metadef/inc/graph/node.h View File

@@ -1,213 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_NODE_H_
#define INC_GRAPH_NODE_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <unordered_set>
#include "graph/ge_attr_value.h"
#include "utils/attr_utils.h"

#include "graph/op_desc.h"
#include "graph/range_vistor.h"

namespace ge {
class ComputeGraph;

using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;

class Node;

using NodePtr = std::shared_ptr<Node>;
using ConstNodePtr = std::shared_ptr<const Node>;
using NodeRef = std::weak_ptr<Node>;

class Anchor;

using AnchorPtr = std::shared_ptr<Anchor>;

class InDataAnchor;

using InDataAnchorPtr = std::shared_ptr<InDataAnchor>;

class OutDataAnchor;

using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>;

class ControlAnchor;

using ControlAnchorPtr = std::shared_ptr<ControlAnchor>;

class InControlAnchor;

using InControlAnchorPtr = std::shared_ptr<InControlAnchor>;

class OutControlAnchor;

using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>;

using OpDescPtr = std::shared_ptr<OpDesc>;

using ConstNode = const Node;

typedef std::vector<std::multimap<std::string, ge::AnchorPtr>> kFusionDataFlowVec_t;

// Node is a component of ComputeGraph
class Node : public std::enable_shared_from_this<Node> {
friend class ComputeGraph;
friend class ModelSerializeImp;

public:
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<ConstNode>>;
~Node();
Node(const Node &) = delete;
Node &operator=(const Node &) = delete;
bool operator==(const Node &r_node) const;

protected:
Node() = default;
Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph);

public:
graphStatus Init();

std::string GetName() const;
std::string GetType() const;

ComputeGraphPtr GetOwnerComputeGraph() const;
graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph);

Vistor<InDataAnchorPtr> GetAllInDataAnchors() const;
Vistor<OutDataAnchorPtr> GetAllOutDataAnchors() const;
uint32_t GetAllInDataAnchorsSize() const;
uint32_t GetAllOutDataAnchorsSize() const;
Vistor<AnchorPtr> GetAllOutAnchors() const;
Vistor<AnchorPtr> GetAllInAnchors() const;
InDataAnchorPtr GetInDataAnchor(int idx) const;
OutDataAnchorPtr GetOutDataAnchor(int idx) const;
InControlAnchorPtr GetInControlAnchor() const;
OutControlAnchorPtr GetOutControlAnchor() const;
Vistor<NodePtr> GetInNodes() const;
Vistor<NodePtr> GetOutNodes() const;
AnchorPtr GetInAnchor(int idx) const;
AnchorPtr GetOutAnchor(int idx) const;

bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const;

// All in Data nodes
Vistor<NodePtr> GetInDataNodes() const;
// All in Control nodes
Vistor<NodePtr> GetInControlNodes() const;
// GetInAllNodes = InDataNodes + InControlNodes
Vistor<NodePtr> GetInAllNodes() const;

// All out Data nodes
Vistor<NodePtr> GetOutDataNodes() const;
uint32_t GetOutDataNodesSize() const;
// All out Control nodes
Vistor<NodePtr> GetOutControlNodes() const;
// GetOutAllNodes = OutDataNodes + InControlNodes
Vistor<NodePtr> GetOutAllNodes() const;

// Get all in data nodes and its out-anchor
Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const;

// Get all out data nodes and its in-anchor
Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const;

graphStatus InferShapeAndType() const;
graphStatus Verify() const;

graphStatus InferOriginFormat() const;

OpDescPtr GetOpDesc() const;

graphStatus UpdateOpDesc(const OpDescPtr &op);

graphStatus AddLinkFrom(const NodePtr &input_node);

graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node);

graphStatus AddLinkFrom(const string &name, NodePtr input_node);

graphStatus AddLinkFromForParse(const NodePtr &input_node);

void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); }

void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); }

const std::vector<uint32_t> &GetSendEventIdList() const { return send_event_id_list_; }

const std::vector<uint32_t> &GetRecvEventIdList() const { return recv_event_id_list_; }
void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) {
fusion_input_list = fusion_input_dataflow_list_;
}

void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) {
fusion_output_list = fusion_output_dataflow_list_;
}

void SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) {
fusion_input_dataflow_list_ = fusion_input_list;
}

void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) {
fusion_output_dataflow_list_ = fusion_output_list;
}

bool GetHostNode() const { return host_node_; }
void SetHostNode(bool is_host) { host_node_ = is_host; }

void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; }

NodePtr GetOrigNode() { return orig_node_; }

private:
bool NodeMembersAreEqual(const Node &r_node) const;
bool NodeAttrsAreEqual(const Node &r_node) const;
bool NodeInConnectsAreEqual(const Node &r_node) const;
bool NodeOutConnectsAreEqual(const Node &r_node) const;
bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const;
OpDescPtr op_;
std::weak_ptr<ComputeGraph> owner_graph_;
vector<InDataAnchorPtr> in_data_anchors_;
vector<OutDataAnchorPtr> out_data_anchors_;
InControlAnchorPtr in_control_anchor_;
OutControlAnchorPtr out_control_anchor_;
map<string, GeAttrValue> attrs_; // lint !e1073
bool has_init_{false};
bool host_node_{false};
bool anchor_status_updated_{false};
std::vector<uint32_t> send_event_id_list_;
std::vector<uint32_t> recv_event_id_list_;

kFusionDataFlowVec_t fusion_input_dataflow_list_;
kFusionDataFlowVec_t fusion_output_dataflow_list_;

NodePtr orig_node_;
friend class NodeUtils;
friend class OnnxUtils;
friend class TuningUtils;
};
} // namespace ge

#endif // INC_GRAPH_NODE_H_

+ 0
- 328
metadef/inc/graph/op_desc.h View File

@@ -1,328 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_OP_DESC_H_
#define INC_GRAPH_OP_DESC_H_

#include <functional>
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "detail/attributes_holder.h"
#include "graph/range_vistor.h"

#define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index)))
#define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index)))
namespace ge {
using std::map;
using std::pair;
using std::shared_ptr;
using std::string;
using std::vector;

class Operator;
class GeTensorDesc;

using GeTensorDescPtr = shared_ptr<GeTensorDesc>;
using ConstGeTensorDescPtr = shared_ptr<const GeTensorDesc>;

class OpDesc;

using OpDescPtr = shared_ptr<OpDesc>;
using ConstOpDescPtr = shared_ptr<const OpDesc>;

class GeAttrValue;

using ConstOpDesc = const OpDesc;

enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd };

class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
public:
template <class T>
using Vistor = RangeVistor<T, shared_ptr<ConstOpDesc>>;

friend class GraphBuilderImpl;

friend class OperatorImpl;

OpDesc(const string &name, const string &type);

OpDesc();

~OpDesc();

bool operator==(const OpDesc &r_op_desc) const;

string GetName() const;

void SetName(const string &name);

string GetType() const;

void SetType(const string &type);

graphStatus AddInputDesc(const GeTensorDesc &input_desc);

graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc);

graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc);

graphStatus AddInputDescForward(const string &name, const unsigned int num);

graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index);

graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index);

graphStatus AddOutputDescForward(const string &name, const unsigned int num);

graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc);

graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc);

graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc);

bool InputIsSet(const string &name) const;

GeTensorDesc GetInputDesc(uint32_t index) const;

GeTensorDesc GetInputDesc(const string &name) const;

Vistor<string> GetAllInputNames() const;

GeTensorDescPtr MutableInputDesc(uint32_t index) const;

GeTensorDescPtr MutableInputDesc(const string &name) const;

Vistor<GeTensorDesc> GetAllInputsDesc() const;

Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const;

size_t GetInputsSize() const;

size_t GetAllInputsSize() const;

graphStatus AddOutputDesc(const GeTensorDesc &output_desc);

graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc);

graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc);

graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc);

GeTensorDesc GetOutputDesc(uint32_t index) const;

GeTensorDesc GetOutputDesc(const string &name) const;

GeTensorDescPtr MutableOutputDesc(uint32_t index) const;

GeTensorDescPtr MutableOutputDesc(const string &name) const;

uint32_t GetAllOutputsDescSize() const;

Vistor<GeTensorDesc> GetAllOutputsDesc() const;

Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const;

size_t GetOutputsSize() const;

ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const;

ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const;

ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const;

ConstGeTensorDescPtr GetInputDescPtr(const string &name) const;

graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true);

graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index);

graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true);

bool IsOptionalInput(const string &name) const;

bool IsOptionalInput(uint32_t index) const;

std::map<string, uint32_t> GetAllInputName() const;

std::map<string, uint32_t> GetAllOutputName();

bool UpdateInputName(std::map<string, uint32_t> inputNameIdx);

bool UpdateOutputName(std::map<string, uint32_t> outputNameIdx);

void AddInferFunc(const std::function<graphStatus(Operator &)> &func);

std::function<graphStatus(Operator &)> GetInferFunc() const;

graphStatus InferShapeAndType();

void AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func);

std::function<graphStatus(Operator &)> GetInferFormatFunc() const;

graphStatus DefaultInferFormat();

std::function<graphStatus(Operator &)> GetVerifyFunc() const;

void AddVerifierFunc(const std::function<graphStatus(Operator &)> &func);

graphStatus CallInferFormatFunc(Operator &op);

graphStatus OpVerify();

graphStatus CommonVerify() const;

graphStatus AddRegisterInputName(const string &name);

graphStatus AddRegisterOutputName(const string &name);

vector<string> GetRegisterInputName() const;

vector<string> GetRegisterOutputName() const;

using AttrHolder::AddRequiredAttr;
using AttrHolder::DelAttr;
using AttrHolder::GetAllAttrNames;
using AttrHolder::GetAllAttrs;
using AttrHolder::GetAttr;
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;

void SetId(int64_t id);
int64_t GetId() const;
void SetStreamId(int64_t stream_id);
int64_t GetStreamId() const;
void SetInputName(const vector<string> &input_name);
vector<string> GetInputName() const;
void SetSrcName(const vector<string> &src_name);
vector<string> GetSrcName() const;
void SetSrcIndex(const vector<int64_t> &src_index);
vector<int64_t> GetSrcIndex() const;
void SetInputOffset(const vector<int64_t> &input);
vector<int64_t> GetInputOffset() const;
void SetOutputOffset(const vector<int64_t> &input);
vector<int64_t> GetOutputOffset() const;
void SetDstName(const vector<string> &dst_name);
vector<string> GetDstName() const;
void SetDstIndex(const vector<int64_t> &dst_index);
vector<int64_t> GetDstIndex() const;
void SetWorkspace(const vector<int64_t> &workspace);
vector<int64_t> GetWorkspace() const;
void SetWorkspaceBytes(const vector<int64_t> &workspace_bytes);
vector<int64_t> GetWorkspaceBytes() const;
void SetIsInputConst(const vector<bool> &is_input_const);
vector<bool> GetIsInputConst() const;

void SetOpInferDepends(const vector<string> &depend_names);
vector<string> GetOpInferDepends() const;

string GetInputNameByIndex(uint32_t index) const;

int GetInputIndexByName(const string &name) const;

string GetOutputNameByIndex(uint32_t index) const;

int GetOutputIndexByName(const string &name) const;

graphStatus RestoreInputNameIdx(const string &name, const int &index);

graphStatus RestoreOutputNameIdx(const string &name, const int &index);

graphStatus CallInferFunc(Operator &op);

void SetOpKernelLibName(const std::string &name);

std::string GetOpKernelLibName() const;

void SetOpEngineName(const std::string &name);

std::string GetOpEngineName() const;

void RegisterSubgraphIrName(const std::string &name, SubgraphType type);
const std::map<std::string, SubgraphType> &GetSubgraphIrNames() const;
SubgraphType GetSubgraphTypeByIrName(const std::string &name) const;

graphStatus AddSubgraphName(const std::string &name);
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const;

std::string GetSubgraphInstanceName(uint32_t index) const;
const std::vector<std::string> &GetSubgraphInstanceNames() const;
/// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`,
/// because this kind of functions will only append a new subgraph instance name
/// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`.
/// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first.
/// \param index
/// \param name
/// \return
graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name);
void RemoveSubgraphInstanceName(const std::string &name);

graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const;

protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;

private:
OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def);
bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const;
bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const;
bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const;

GeIrProtoHelper<ge::proto::OpDef> op_def_;
std::vector<std::string> subgraph_instance_names_;

// subgraph names to index, for a `if` operator:
// then_branch: 0
// else_branch: 1
// or for a `case` node:
// branches0: 0
// branches1: 1
// branches2: 2
std::map<std::string, uint32_t> subgraph_names_to_index_;

// subgraph ir names to type, for a `if` operator:
// then_branch: static
// else_branch: static
// or for a `case` op:
// branches: dynamic
std::map<std::string, SubgraphType> subgraph_ir_names_to_type_;

vector<GeTensorDescPtr> inputs_desc_{};
map<string, uint32_t> input_name_idx_{};
vector<string> register_input_name_{};
std::unordered_set<string> optional_input_names_{};
vector<GeTensorDescPtr> outputs_desc_{};
map<string, uint32_t> output_name_idx_{};
vector<string> register_output_name_{};
std::function<graphStatus(Operator &)> infer_func_ = nullptr;
std::function<graphStatus(Operator &)> infer_format_func_ = nullptr;
std::function<graphStatus(Operator &)> verifier_func_ = nullptr;
string op_kernel_lib_name_;
string engine_name_;
friend class OpDescUtils;
friend class ModelSerializeImp;
friend class AttrUtils;
friend class GeAttrValueImp;
friend class OnnxUtils;
};
} // namespace ge
#endif // INC_GRAPH_OP_DESC_H_

+ 0
- 48
metadef/inc/graph/op_kernel_bin.h View File

@@ -1,48 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_OP_KERNEL_BIN_H_
#define INC_GRAPH_OP_KERNEL_BIN_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace ge {
class OpKernelBin {
public:
OpKernelBin(std::string name, std::vector<char> &&data) : name_(std::move(name)), data_(std::move(data)) {}

~OpKernelBin() = default;

const std::string &GetName() const { return name_; }
const uint8_t *GetBinData() const { return (const uint8_t *)data_.data(); }
size_t GetBinDataSize() const { return data_.size(); }
OpKernelBin(const OpKernelBin &) = delete;
const OpKernelBin &operator=(const OpKernelBin &) = delete;

private:
std::string name_;
std::vector<char> data_;
};

using OpKernelBinPtr = std::shared_ptr<OpKernelBin>;
const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel";
const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel";
} // namespace ge

#endif // INC_GRAPH_OP_KERNEL_BIN_H_

+ 0
- 56
metadef/inc/graph/operator_factory_impl.h View File

@@ -1,56 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_
#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "graph/operator_factory.h"

namespace ge {
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl {
public:
static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type);

static graphStatus GetOpsTypeList(std::vector<std::string> &all_ops);

static bool IsExistOp(const string &operator_type);

static InferShapeFunc GetInferShapeFunc(const std::string &operator_type);

static InferFormatFunc GetInferFormatFunc(const std::string &operator_type);

static VerifyFunc GetVerifyFunc(const std::string &operator_type);

static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator);

static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func);

static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func);

static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func);

static shared_ptr<std::map<string, OpCreator>> operator_creators_;
static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_;
static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_;
static shared_ptr<std::map<string, VerifyFunc>> operator_verify_funcs_;
};
} // namespace ge

#endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_

+ 0
- 46
metadef/inc/graph/opsproto_manager.h View File

@@ -1,46 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_
#define INC_GRAPH_OPSPROTO_MANAGER_H_

#include <dirent.h>
#include <dlfcn.h>
#include <string.h>
#include <map>
#include <string>
#include <vector>
#include <mutex>

namespace ge {
class OpsProtoManager {
public:
static OpsProtoManager *Instance();

bool Initialize(const std::map<std::string, std::string> &options);
void Finalize();

private:
void LoadOpsProtoPluginSo(std::string &path);

std::string pluginPath_;
std::vector<void *> handles_;
bool is_init_ = false;
std::mutex mutex_;
};
} // namespace ge

#endif // INC_GRAPH_OPSPROTO_MANAGER_H_

+ 0
- 53
metadef/inc/graph/range_vistor.h View File

@@ -1,53 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_RANGE_VISTOR_H_
#define INC_GRAPH_RANGE_VISTOR_H_

#include <vector>

template <class E, class O>
class RangeVistor {
public:
using Iterator = typename std::vector<E>::iterator;
using ConstIterator = typename std::vector<E>::const_iterator;

RangeVistor(O owner, const std::vector<E> &vs) : owner_(owner), elements_(vs) {}

~RangeVistor() {}

Iterator begin() { return elements_.begin(); }

Iterator end() { return elements_.end(); }

ConstIterator begin() const { return elements_.begin(); }

ConstIterator end() const { return elements_.end(); }

std::size_t size() const { return elements_.size(); }

bool empty() const { return elements_.empty(); }

E &at(std::size_t index) { return elements_.at(index); }

const E &at(std::size_t index) const { return elements_.at(index); }

private:
O owner_;
std::vector<E> elements_;
};

#endif // INC_GRAPH_RANGE_VISTOR_H_

+ 0
- 79
metadef/inc/graph/ref_relation.h View File

@@ -1,79 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_GRAPH_REF_RELATION_H_
#define COMMON_GRAPH_REF_RELATION_H_

#include <deque>
#include <string>
#include <unordered_map>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/types.h"
#include "graph/ge_error_codes.h"
#include "node.h"

namespace ge {
enum InOutFlag {
NODE_IN = 0, // input flag
NODE_OUT = 1, // output flag
};

struct RefCell {
std::string node_name;
ge::NodePtr node = nullptr;
InOutFlag in_out = NODE_IN;
int in_out_idx = 0;

bool operator==(const RefCell &c) const {
return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx;
}

RefCell() = default;
RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) {
node_name = name;
node = node_ptr;
in_out = in_out_flag;
in_out_idx = idx;
};
~RefCell() = default;
};

struct RefCellHash {
size_t operator()(const RefCell &c) const {
unsigned long number = reinterpret_cast<unsigned long>(reinterpret_cast<uintptr_t>(c.node.get()));
string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number);
return std::hash<string>()(tmp);
}
};

class RefRelations {
public:
graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result);
graphStatus BuildRefRelations(ge::ComputeGraph &root_graph);
graphStatus Clear();

RefRelations();
~RefRelations() = default;

public:
class Impl;
std::shared_ptr<Impl> impl_ = nullptr;
};

} // namespace ge
#endif // COMMON_GRAPH_REF_RELATION_H_

+ 0
- 46
metadef/inc/graph/runtime_inference_context.h View File

@@ -1,46 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_
#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_

#include <map>
#include <memory>
#include <mutex>
#include <vector>
#include "external/graph/ge_error_codes.h"
#include "external/graph/tensor.h"

namespace ge {
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext {
public:
static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx);
static graphStatus CreateContext(const std::string &context_id);
static void DestroyContext(const std::string &context_id);

graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor);
graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor);

private:
std::map<int64_t, std::vector<Tensor>> tensors_;
std::mutex mu_;

static std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> contexts_;
static std::mutex ctx_mu_;
};
} // namespace ge

#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_

+ 0
- 40
metadef/inc/graph/shape_refiner.h View File

@@ -1,40 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_SHAPE_REFINER_H_
#define INC_GRAPH_SHAPE_REFINER_H_

#include <string>
#include "external/graph/inference_context.h"

#include "external/graph/ge_error_codes.h"
#include "graph/node.h"

namespace ge {
// ShapeRefiner performs shape inference for compute graphs
class ShapeRefiner {
public:
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph);
static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph);
static graphStatus InferShapeAndType(const NodePtr &node);
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op);
static void ClearContextMap();

private:
static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase);
};
} // namespace ge
#endif // INC_GRAPH_SHAPE_REFINER_H_

+ 0
- 130
metadef/inc/graph/tuning_utils.h View File

@@ -1,130 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MAIN_TUNING_UTILS_H
#define MAIN_TUNING_UTILS_H

#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <queue>
#include <mutex>

#include <graph/anchor.h>
#include <graph/detail/attributes_holder.h>
#include <graph/ge_tensor.h>
#include <graph/graph.h>
#include <graph/model.h>
#include <graph/node.h>
#include <graph/utils/graph_utils.h>
#include <graph/utils/type_utils.h>

#include "framework/common/debug/ge_log.h"
#include "utils/attr_utils.h"
#include "utils/node_utils.h"
#include "external/ge/ge_api_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
namespace ge {
// Configure build mode, default value is "normal"
const char *const BUILD_MODE = "ge.buildMode";
const char *const BUILD_STEP = "ge.buildStep";
// Configure tuning path
const char *const TUNING_PATH = "ge.tuningPath";
// for interface: aclgrphBuildModel
const std::set<std::string> ir_builder_supported_options_for_lx_fusion = {BUILD_MODE, BUILD_STEP, TUNING_PATH};

// Build model
const char *const BUILD_MODE_NORMAL = "normal";
const char *const BUILD_MODE_TUNING = "tuning";
const char *const BUILD_MODE_BASELINE = "baseline";
const std::set<std::string> build_mode_options = {BUILD_MODE_NORMAL, BUILD_MODE_TUNING, BUILD_MODE_BASELINE};

// Build step
const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match";
const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match";
const char *const BUILD_STEP_AFTER_BUILDER = "after_builder";
const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub";
const char *const BUILD_STEP_AFTER_MERGE = "after_merge";
const std::set<std::string> build_step_options = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_UB_MATCH,
BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB,
BUILD_STEP_AFTER_MERGE};

using SubgraphCreateOutNode = std::unordered_map<ComputeGraphPtr, NodePtr>;
using NodetoNodeMap = std::unordered_map<NodePtr, NodePtr>;
using NodeSet = std::set<NodePtr>;
using NodeNametoNodeNameMap = std::unordered_map<std::string, std::string>;
using NodetoNodeNameMap = std::unordered_map<NodePtr, std::string>;
class TuningUtils {
public:
TuningUtils() = default;
~TuningUtils() = default;
// Dump all the subgraphs and modify
// the subgraphs in them to be executable subgraphs if exe_flag is true
// `tuning_path` means path to save the graphs
static graphStatus ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs,
std::vector<ComputeGraphPtr> non_tuning_subgraphs = {}, bool exe_flag = false,
const std::string &path = "", const std::string &user_path = "");
// Recovery `graph` from graph dump files configured in options
static graphStatus ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph);

private:
// part 1
struct HelpInfo {
int64_t index;
bool exe_flag;
bool is_tuning_graph;
const std::string &path;
const std::string &user_path;
};
static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info);
static graphStatus HandlePld(NodePtr &node);
static graphStatus HandleEnd(NodePtr &node);
static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node);
static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node);
static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node);
static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node);
static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node);
static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node);
static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node);
static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path);

static SubgraphCreateOutNode create_output_;
// part 2
static graphStatus MergeAllSubGraph(std::vector<ComputeGraphPtr> &graphs, ComputeGraphPtr &graph);
static graphStatus MergeSubGraph(ComputeGraphPtr &graph);
// Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1
static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph);
static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor,
AnchorPtr &src_out_anchor);
static NodeNametoNodeNameMap data_2_netoutput_;
static NodetoNodeNameMap data_node_2_netoutput_;
static NodetoNodeMap data_node_2_netoutput_node_;
static NodeSet netoutput_nodes_;
static NodeSet merged_graph_nodes_;
static std::mutex mutex_;
// for debug
static std::string PrintCheckLog();
static std::string GetNodeNameByAnchor(const Anchor *anchor);
};
} // namespace ge
#endif // MAIN_TUNING_UTILS_H

+ 0
- 133
metadef/inc/graph/usr_types.h View File

@@ -1,133 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_USR_TYPES_H_
#define INC_GRAPH_USR_TYPES_H_

#include <atomic>
#include <memory>
#include <vector>
namespace ge {
#define USR_TYPE_DEC(type, name) \
inline void set_##name(const type &value) { name = value; } \
type *mutable_##name() { return &name; }

#define USR_TYPE_HAS_DEC(type, name) \
inline void set_##name(const type &value) { name = value; } \
\
private: \
bool has_mutable_##name{false}; \
\
public: \
bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \
type *mutable_##name() { \
has_mutable_##name = true; \
return &name; \
}

#define USR_TYPE_BYTES_DEC(name) \
inline void clear_##name() { name.clear(); } \
inline void set_##name(const void *value, size_t size) { \
name.assign(reinterpret_cast<uint8_t *>(const_cast<void *>(value)), \
reinterpret_cast<uint8_t *>(const_cast<void *>(value)) + size); \
}

enum UsrQuantizeScaleType { USR_VECTOR_SCALE = 0, USR_SCALAR_SCALE = 1 };
enum UsrQuantizeScaleMode { USR_NORMAL_MODE = 0, USR_SQRT_MODE = 1 };
enum UsrQuantizeAlgorithm {
USR_NON_OFFSET_ALGO = 0,
USR_HALF_OFFSET_ALGO = 1,
USR_ALL_OFFSET_ALGO = 2,
};

struct UsrQuantizeFactor {
public:
// QuantizeScaleMode scale_mode;
UsrQuantizeScaleMode scale_mode{USR_NORMAL_MODE};
std::vector<uint8_t> scale_value;
int64_t scale_offset{0};
std::vector<uint8_t> offset_data_value;
int64_t offset_data_offset{0};
std::vector<uint8_t> offset_weight_value;
int64_t offset_weight_offset{0};
std::vector<uint8_t> offset_pad_value;
int64_t offset_pad_offset{0};

USR_TYPE_DEC(UsrQuantizeScaleMode, scale_mode);
USR_TYPE_BYTES_DEC(scale_value);

USR_TYPE_DEC(int64_t, scale_offset);
USR_TYPE_BYTES_DEC(offset_data_value);
USR_TYPE_DEC(int64_t, offset_data_offset);

USR_TYPE_BYTES_DEC(offset_weight_value);
USR_TYPE_DEC(int64_t, offset_weight_offset);
USR_TYPE_BYTES_DEC(offset_pad_value);
USR_TYPE_DEC(int64_t, offset_pad_offset);
};

static inline bool QuantizeFactorHasData(const UsrQuantizeFactor &factor) {
return factor.scale_value.size() > 0 || factor.offset_data_value.size() > 0 ||
factor.offset_weight_value.size() > 0 || factor.offset_pad_value.size() > 0;
}

struct UsrQuantizeCalcFactor {
public:
std::vector<uint8_t> offsetw;
int64_t offsetw_offset{0};
std::vector<uint8_t> offsetd;
int64_t offsetd_offset{0};
std::vector<uint8_t> scalereq;
int64_t scaledreq_offset{0};
std::vector<uint8_t> offsetdnext;
int64_t offsetdnext_offset{0};

USR_TYPE_BYTES_DEC(offsetw);
USR_TYPE_DEC(int64_t, offsetw_offset);
USR_TYPE_BYTES_DEC(offsetd);
USR_TYPE_DEC(int64_t, offsetd_offset);
USR_TYPE_BYTES_DEC(scalereq);
USR_TYPE_DEC(int64_t, scaledreq_offset);
USR_TYPE_BYTES_DEC(offsetdnext);
USR_TYPE_DEC(int64_t, offsetdnext_offset);
};

static inline bool QuantizeFactorHasData(const UsrQuantizeCalcFactor &factor) {
return factor.offsetw.size() > 0 || factor.offsetd.size() > 0 || factor.scalereq.size() > 0 ||
factor.offsetdnext.size() > 0;
}

struct UsrQuantizeFactorParams {
UsrQuantizeAlgorithm quantize_algo{USR_NON_OFFSET_ALGO};
UsrQuantizeScaleType scale_type{USR_VECTOR_SCALE};
UsrQuantizeFactor quantize_param;
UsrQuantizeFactor dequantize_param;
UsrQuantizeFactor requantize_param;
UsrQuantizeCalcFactor quantizecalc_param;
USR_TYPE_DEC(UsrQuantizeAlgorithm, quantize_algo);
USR_TYPE_DEC(UsrQuantizeScaleType, scale_type);
USR_TYPE_HAS_DEC(UsrQuantizeFactor, quantize_param);
USR_TYPE_HAS_DEC(UsrQuantizeFactor, dequantize_param);
USR_TYPE_HAS_DEC(UsrQuantizeFactor, requantize_param);
USR_TYPE_HAS_DEC(UsrQuantizeCalcFactor, quantizecalc_param);
};

#undef USR_TYPE_DEC
#undef USR_TYPE_HAS_DEC
#undef USR_TYPE_BYTES_DEC
} // namespace ge

#endif // INC_GRAPH_USR_TYPES_H_

+ 0
- 45
metadef/inc/graph/utils/anchor_utils.h View File

@@ -1,45 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_
#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_

#include "graph/anchor.h"
#include "graph/node.h"

namespace ge {
class AnchorUtils {
public:
// Get anchor format
static Format GetFormat(const DataAnchorPtr &dataAnchor);

// Set anchor format
static graphStatus SetFormat(const DataAnchorPtr &dataAnchor, Format dataFormat);

// Get anchor status
static AnchorStatus GetStatus(const DataAnchorPtr &dataAnchor);

// Set anchor status
static graphStatus SetStatus(const DataAnchorPtr &dataAnchor, AnchorStatus anchorStatus);

static bool HasControlEdge(const AnchorPtr &anchor);

static bool IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst);

static int GetIdx(const AnchorPtr &anchor);
};
} // namespace ge
#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_

+ 0
- 150
metadef/inc/graph/utils/attr_utils.h View File

@@ -1,150 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_
#define INC_GRAPH_UTILS_ATTR_UTILS_H_

#include <memory>
#include <string>
#include <vector>
#include "graph/detail/attributes_holder.h"
#include "graph/ge_attr_value.h"
#include "graph/types.h"

namespace ge {
class OpDesc;
using OpDescPtr = std::shared_ptr<OpDesc>;
using ConstOpDescPtr = std::shared_ptr<const OpDesc>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils {
public:
class ConstAttrHolderAdapter;
class AttrHolderAdapter;
// Set
static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name);

static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int64_t> &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<uint32_t> &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int32_t> &value);
static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value);

static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value);
static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector<float> &value);
static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value);
static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector<bool> &value);
static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value);
static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector<string> &value);
static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value);
static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorDesc> &value);
static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value);
static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value);
static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorPtr> &value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<ConstGeTensorPtr> &value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name,
std::initializer_list<ConstGeTensorPtr> &&value);
static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensor> &value);
static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value);
static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value);
static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value);
static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value);
static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value);
static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name,
const vector<GeAttrValue::NAMED_ATTRS> &value);
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value);
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value);

// Get
static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value);
static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value);
static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value);
static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int64_t> &value);
static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int32_t> &value);
static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<uint32_t> &value);
static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value);
static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector<float> &value);
static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value);
static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector<bool> &value);
static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value);
static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector<string> &value);
static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value);
static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<GeTensorDesc> &value);
static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value);
static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value);
static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value);
static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value);
static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value);
static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value);
static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value);
static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value);
static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value);
static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name,
vector<GeAttrValue::NAMED_ATTRS> &value);
static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value);
// Value will be moved
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer);
static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer);
// Value will be moved
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);
static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);

static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value);
static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<vector<int64_t>> &value);

static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector<ge::DataType> &value);
static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector<ge::DataType> &value);

static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value);
static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value);

static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc);

static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc);

static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj);

class AttrHolderAdapter {
public:
AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {}
~AttrHolderAdapter() {}
template <class T>
AttrHolderAdapter(const std::shared_ptr<T> &obj) : obj_(obj.get()) {}
AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {}
operator bool() const { return obj_ != nullptr; }
AttrHolder *operator->() { return obj_; }
AttrHolder *get() { return obj_; }

AttrHolder *obj_;
};

class ConstAttrHolderAdapter {
public:
ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {}
~ConstAttrHolderAdapter() {}
template <class T>
ConstAttrHolderAdapter(const std::shared_ptr<T> obj) : obj_(obj.get()) {}
ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {}
operator bool() const { return obj_ != nullptr; }
const AttrHolder *operator->() const { return obj_; }
const AttrHolder *get() const { return obj_; }

private:
const AttrHolder *obj_;
};
};
} // namespace ge
#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_

+ 0
- 771
metadef/inc/graph/utils/graph_utils.h View File

@@ -1,771 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_
#define INC_GRAPH_UTILS_GRAPH_UTILS_H_

#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <list>
#include <unordered_map>

#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/compute_graph.h"
#include "graph/utils/anchor_utils.h"
#include "graph/graph.h"
#include "graph/model.h"

#define GE_DUMP(compute_graph, name) \
do { \
GraphUtils::DumpGEGraph(compute_graph, name); \
GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \
uint64_t i = 0; \
for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \
auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \
GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \
GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \
} \
} while (0)

#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
do { \
DataType ret; \
attr.GetValue<DataType>(ret); \
} while (0)

#define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
do { \
if (value_type == VT_ENUM) { \
REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
stream << ret; \
} \
} while (0)

#define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
do { \
if (value_type == VT_ENUM) { \
REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
stream << "["; \
for (int i = 0; i < ret.size(); i++) { \
stream << ret[i]; \
if (i + 1 != ret.size()) stream << ", "; \
} \
stream << "]"; \
} \
} while (0)

#define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)

#define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)

#define PRINT_SHAPE(i_o, n, idx, stream) \
do { \
auto op = n->GetOpDesc(); \
GeTensorDesc td = i_o == "input" ? op->GetInputDesc(idx) : op->GetOutputDesc(idx); \
auto shape = td.GetShape().GetDims(); \
stream << "["; \
for (int i = 0; i < shape.size(); i++) { \
stream << shape[i]; \
if (i + 1 < shape.size()) stream << ", "; \
} \
stream << "]"; \
} while (0)

#define PRINT_ATTR_FUNC(stream) \
[&](GeAttrValue attr) { \
auto type = attr.GetValueType(); \
PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \
PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \
PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \
PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \
PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \
PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \
PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \
PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \
else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \
else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \
else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \
else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \
else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \
else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \
};

namespace ge {
enum IOType { kIn, kOut };

struct NodeIndexIO {
NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type)
: node_(std::move(node)), index_(index), io_type_(io_type) {
if (node_ != nullptr) {
value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_);
}
}
NodeIndexIO(ge::NodePtr node, int index, IOType io_type)
: node_(std::move(node)), index_(static_cast<uint32_t>(index)), io_type_(io_type) {
if (node_ != nullptr) {
value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_);
}
}
~NodeIndexIO() {}

NodePtr node_ = nullptr;
uint32_t index_ = 0;
IOType io_type_ = kOut;
std::string value_;

const std::string &ToString() const { return value_; }
};

class GraphUtils {
public:
static ComputeGraphPtr GetComputeGraph(const Graph &graph);

static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph);

static graphStatus RecoverGraphOperators(const Graph &graph);

static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs);

static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);

static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst,
const Format &dst_format);

static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst);

static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);

static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);

// check whether src is link to dst and then remove
static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);

static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst);

static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);

static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);

static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
const InDataAnchorPtr &new_dst);

static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst,
const InControlAnchorPtr &new_dst);

static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
const NodePtr &new_node);

static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node);

static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);

static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
const std::vector<OpDescPtr> &vec_op_desc);

///
/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst
/// @param [in] src
/// @param [in] dsts
/// @param [in] insert_node
/// @param [in] input_index
/// @param [in] output_index
/// @return graphStatus
///
static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);

static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);

static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node);

static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node);

static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node);

static bool MatchDumpStr(const std::string &suffix);

static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false,
const std::string &user_graph_name = "");

static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);

static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph);

static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos);

static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);

static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph);

static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message);

static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path);

static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node);

///
/// Isolating `node`, relinking data links from the in-anchor peer nodes to
/// the out-anchor peer nodes according to `io_map`, relinking control links
/// to ensure that input nodes of `node` are before out nodes
///
/// Link the `io_map[i]` input anchor peer node to `i` output anchor peer
/// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0,
/// unlink all links from `i` output anchor without any relinking.
///
/// @param node
/// @param io_map
/// @return
///
static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list<int> &io_map);
static graphStatus IsolateNode(const NodePtr &node, const std::vector<int> &io_map);

///
/// Isolate `node` which must be one input one output, equivalent to
/// `IsolateNode(node, {0})`
/// @param node
/// @return
///
static graphStatus IsolateNodeOneIO(const NodePtr &node);

///
/// The data anchors replacing behavior is the same with
/// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control
/// anchors with `new_node`'s.
/// @param new_node
/// @param old_node
/// @param inputs_map
/// @param outputs_map
/// @return
///
static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
std::initializer_list<int> inputs_map, std::initializer_list<int> outputs_map);

static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);

///
/// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`.
/// Replace the `i` in/out data anchor on `old_node` with
/// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`.
/// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in
/// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain
/// on `old_node`.
/// @param new_node
/// @param old_node
/// @param inputs_map
/// @param outputs_map
/// @return
///
static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
std::initializer_list<int> inputs_map,
std::initializer_list<int> outputs_map);

static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);

///
/// Copy all in-control edges from `src_node` to `dst_node`
/// @param src_node
/// @param dst_node
/// @return
///
static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);

static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);

///
/// Copy all out-control edges from `src_node` to `dst_node`
/// @param src_node
/// @param dst_node
/// @return success: GRAPH_SUCESS
///
static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);

///
/// Move all out-control edges from `src_node` to `dst_node`
/// @param src_node
/// @param dst_node
/// @return success: GRAPH_SUCESS
///
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);

///
/// Copy all in-data edges from `src_node` to `dst_node`
/// @param src_node
/// @param dst_node
/// @return
///
static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node);

static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);

///
/// Make a copy of ComputeGraph.
/// @param graph: original graph.
/// @param prefix: node name prefix of new graph.
/// @return ComputeGraphPtr
///
static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix,
std::vector<NodePtr> &input_nodes, std::vector<NodePtr> &output_nodes);

///
/// Copy tensor attribute to new node.
/// @param [in] dst_desc: cloned node.
/// @param [in] src_node: original node.
/// @return success: GRAPH_SUCESS
///
static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node);

static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);

///
/// Get reference-mapping of all data_anchors in graph
/// @param [in] graph
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus GetRefMapping(const ComputeGraphPtr &graph,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs
/// of the graph have UNKNOWN_SHAPE operators or not.
/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following
/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE
/// ROOT graph: A -----> B -----> C
/// K subgraph U
/// |
/// V
/// SUB graph: D --> E --> F
/// K K K
/// @param [in] graph
/// @return bool
///
static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph);

static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name);

private:
///
/// Get reference-mapping for in_data_anchors of node
/// @param [in] node
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus HandleInAnchorMapping(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Get reference-mapping for out_data_anchors of node
/// @param [in] node
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus HandleOutAnchorMapping(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Handle input of subgraph
/// @param [in] node
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus HandleSubgraphInput(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Handle input of Merge op
/// @param [in] node
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus HandleMergeInput(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Handle output of subgraph
/// @param [in] node
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus HandleSubgraphOutput(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Relink all edges for cloned ComputeGraph.
/// @param [in] node: original node.
/// @param [in] prefix: node name prefix of new node.
/// @param [in] all_nodes: all nodes in new graph.
/// @return success: GRAPH_SUCESS
///
static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix,
const std::unordered_map<string, NodePtr> &all_nodes);

///
/// Union ref-mapping
/// @param [in] exist_node_info1
/// @param [in] exist_node_info2
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @param [out] symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol);

///
/// Update symbol mapping with a new reference pair
/// @param [in] cur_node_info
/// @param [in] exist_node_info
/// @param [out] symbol_to_anchors
/// @param [out] anchor_to_symbol
/// @return success: GRAPH_SUCESS
///
static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);

///
/// Check if out_data_anchor is reference of input
/// @param [in] out_data_anchor
/// @param [out] reuse_in_index
/// @return bool
///
static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index);
};

class ComputeGraphBuilder {
public:
ComputeGraphBuilder() : owner_graph_(nullptr) {}
ComputeGraphBuilder(const ComputeGraphBuilder &) = delete;
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete;
ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete;
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete;
~ComputeGraphBuilder() = default;

///
/// @brief Add node to graph
/// @param [in] op_desc
/// @return ComputeGraphBuilder
///
virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc);

///
/// @brief Add data-link among nodes in graph
/// @param [in] src_name
/// @param [in] out_anchor_ind
/// @param [in] dst_name
/// @param [in] in_anchor_ind
/// @return ComputeGraphBuilder
///
virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind,
const std::string &dst_name, uint32_t in_anchor_ind);

///
/// @brief Add ctrl-link among nodes in graph
/// @param [in] src_name
/// @param [in] dst_name
/// @return ComputeGraphBuilder
///
virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name);

///
/// @brief Build graph
/// @param [out] error_code
/// @param [out] error_msg
/// @return ComputeGraphPtr
///
virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0;

/// @brief Get node with name
/// @param [in] name
/// @return NodePtr
///
NodePtr GetNode(const std::string &name);

/// @brief Get all nodes
/// @return std::vector<NodePtr>
///
std::vector<NodePtr> GetAllNodes();

protected:
///
/// @brief Build nodes
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildNodes(graphStatus &error_code, std::string &error_msg);

///
/// @brief Build data-links
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildDataLinks(graphStatus &error_code, std::string &error_msg);

///
/// @brief Build ctrl-links
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg);

ComputeGraphPtr owner_graph_;

// node_name -> node
std::map<std::string, NodePtr> node_names_;
std::vector<OpDescPtr> nodes_;

// <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind>
std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_;
// src_node_name -> dst_node_name
std::vector<std::pair<std::string, std::string>> ctrl_links_;
};

class CompleteGraphBuilder : public ComputeGraphBuilder {
public:
explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {}
CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete;
~CompleteGraphBuilder() = default;

///
/// @brief Add node to graph
/// @param [in] op_desc
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override;

///
/// @brief Add data-link among nodes in graph
/// @param [in] src_name
/// @param [in] out_anchor_ind
/// @param [in] dst_name
/// @param [in] in_anchor_ind
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
uint32_t in_anchor_ind) override;

///
/// @brief Add ctrl-link among nodes in graph
/// @param [in] src_name
/// @param [in] dst_name
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;

///
/// @brief Set index_th input anchor for graph
/// @param [in] index
/// @param [in] node_names
/// @param [in] anchor_inds
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names,
const std::vector<uint32_t> &anchor_inds);

///
/// @brief Set index_th input of graph as useless
/// @param [in] index
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetUselessInput(uint32_t index);

///
/// @brief Add output anchor for graph
/// @param [in] owner_node_name
/// @param [in] anchor_ind
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);

///
/// @brief Add target for graph
/// @param [in] target_name
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddTarget(const std::string &target_name);

///
/// @brief Set parent-node of graph
/// @param [in] parent_node
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node);

///
/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node
/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);

///
/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind
/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);

///
/// @brief Build graph
/// @param [out] error_code
/// @param [out] error_msg
/// @return ComputeGraphPtr
///
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;

private:
///
/// @brief Add data nodes
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void AddDataNodes(graphStatus &error_code, std::string &error_msg);

///
/// @brief Add data node
/// @param [in] index
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg);

///
/// @brief Add RetVal nodes
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void AddRetValNodes(graphStatus &error_code, std::string &error_msg);

///
/// @brief Build target-nodes for graph
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildGraphTargets(graphStatus &error_code, std::string &error_msg);

std::string name_;
NodePtr parent_node_;
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
std::vector<std::string> graph_targets_;

// index_of_graph_input -> in_anchor_index_of_parent_node
std::map<uint32_t, uint32_t> input_mapping_;
// index_of_graph_output -> out_anchor_index_of_parent_node
std::map<uint32_t, uint32_t> output_mapping_;
};

class PartialGraphBuilder : public ComputeGraphBuilder {
public:
PartialGraphBuilder() = default;
PartialGraphBuilder(const PartialGraphBuilder &) = delete;
PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete;
PartialGraphBuilder(const PartialGraphBuilder &&) = delete;
PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete;
~PartialGraphBuilder() = default;

///
/// @brief Add node to graph
/// @param [in] op_desc
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override;

///
/// @brief Add data-link among nodes in graph
/// @param [in] src_name
/// @param [in] out_anchor_ind
/// @param [in] dst_name
/// @param [in] in_anchor_ind
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
uint32_t in_anchor_ind) override;

///
/// @brief Add ctrl-link among nodes in graph
/// @param [in] src_name
/// @param [in] dst_name
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;

///
/// @brief Set owner graph
/// @param [in] graph
/// @return PartialGraphBuilder
///
PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph);

///
/// @brief Add exist node
/// @param [in] node
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddExistNode(const NodePtr &node);

///
/// @brief Build multi nodes with links
/// @param [out] error_code
/// @param [out] error_msg
/// @return ComputeGraphPtr
///
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;

private:
///
/// @brief Build exist nodes
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildExistNodes(graphStatus &error_code, std::string &error_msg);

std::vector<NodePtr> exist_nodes_;
};
} // namespace ge
#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_

+ 0
- 170
metadef/inc/graph/utils/node_utils.h View File

@@ -1,170 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_
#define INC_GRAPH_UTILS_NODE_UTILS_H_

#include <set>
#include <map>
#include <vector>
#include "external/graph/operator.h"
#include "graph/node.h"

namespace ge {
// Op types of Const like Opps.
extern const std::set<std::string> kConstOpTypes;
// Op types of If like Opps.
extern const std::set<std::string> kIfOpTypes;
// Op types of While like Opps.
extern const std::set<std::string> kWhileOpTypes;
// Op types of Case like Opps.
extern const std::set<std::string> kCaseOpTypes;
// Op types of For like Opps.
extern const std::set<std::string> kForOpTypes;

class NodeUtils {
public:
static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id);
static graphStatus AddRecvEventId(const NodePtr &node, const uint32_t &event_id);
static graphStatus GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send);
static graphStatus GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv);

static graphStatus ClearSendInfo();
static graphStatus ClearRecvInfo();

static graphStatus GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst);

static graphStatus GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
InControlAnchorPtr &in_control);

static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor);
static graphStatus SetAllAnchorStatus(const NodePtr &nodePtr);
static graphStatus SetAllAnchorStatus(Node &node);
static bool IsAnchorStatusSet(const NodePtr &nodePtr);
static bool IsAnchorStatusSet(const Node &node);

static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node);

static void UpdateIsInputConst(const NodePtr &nodePtr);
static void UpdateIsInputConst(Node &node);
static bool IsConst(const Node &node);
static void UnlinkAll(const Node &node);
static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr);

static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num);
static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num);

static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num);
static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num);

static bool IsInNodesEmpty(const Node &node);
static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index);
static GeTensorDesc GetInputDesc(const Node &node, uint32_t index);
static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape);
static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape);
// check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true;
// for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2,
// the out param "is_unknow" will be true too
static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow);

static std::string GetNodeType(const Node &node);
static std::string GetNodeType(const NodePtr &node);

static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index);
static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph);

///
/// Check if node is input of subgraph
/// @param [in] node
/// @return bool
///
static bool IsSubgraphInput(const NodePtr &node);

///
/// Check if node is output of subgraph
/// @param [in] node
/// @return bool
///
static bool IsSubgraphOutput(const NodePtr &node);

///
/// @brief Get subgraph original input node.
/// @param [in] node
/// @return Node
///
static NodePtr GetParentInput(const Node &node);
static NodePtr GetParentInput(const NodePtr &node);

///
/// @brief Get is dynamic shape graph from node.
/// @param [in] node
/// @return bool
///
static bool IsDynamicShape(const Node &node);
static bool IsDynamicShape(const NodePtr &node);

///
/// @brief Check is varying_input for while node
/// @param [in] node: Data node for subgraph
/// @return bool
///
static bool IsWhileVaryingInput(const ge::NodePtr &node);

///
/// @brief Get subgraph input is constant.
/// @param [in] node
/// @param [out] string
/// @return bool
///
static bool GetConstOpType(const NodePtr &node, std::string &type);

///
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
/// @param [in] node
/// @return return GRAPH_SUCCESS if remove successfully, other for failed.
///
static graphStatus RemoveSubgraphsOnNode(const NodePtr &node);

///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index);

///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
static vector<NodePtr> GetSubgraphOutputNodes(const Node &node);

static NodePtr GetInDataNodeByIndex(const Node &node, const int index);

static vector<pair<InDataAnchorPtr, NodePtr>> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index);

static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt);

static graphStatus GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor);

static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor);

private:
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_;
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_;
};
} // namespace ge
#endif // INC_GRAPH_UTILS_NODE_UTILS_H_

+ 0
- 181
metadef/inc/graph/utils/op_desc_utils.h View File

@@ -1,181 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_
#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_

#include <memory>
#include <string>
#include <vector>
#include "graph/def_types.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/range_vistor.h"

namespace ge {
class OpDesc;
using OpDescPtr = std::shared_ptr<OpDesc>;

class OpDescUtils {
public:
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<OpDesc>>;

OpDescUtils() = default;
~OpDescUtils() = default;
static bool HasQuantizeFactorParams(const OpDescPtr& op_desc);
static bool HasQuantizeFactorParams(const OpDesc& op_desc);
static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant);
static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant);
static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant);
static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant);

static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node);
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes);

static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node);
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node);
static vector<GeTensorPtr> MutableWeights(const ge::Node& node);
static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node);
static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights);
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights);
static graphStatus ClearWeights(ge::NodePtr node);

static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index);
static bool ClearInputDesc(const ge::NodePtr& node);
static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index);
static bool ClearOutputDesc(const ge::NodePtr& node);
static vector<ge::NodePtr> GetConstInputs(const ge::Node& node);
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node);
static size_t GetNonConstInputsSize(const ge::Node& node);
static size_t GetNonConstInputsSize(ge::ConstNodePtr node);
// Index: Indicates the index of all non const inputs
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0);
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0);
static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index);
static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index);
// Index: Indicates the index of all inputs
static bool IsNonConstInput(const ge::Node& node, size_t index = 0);
static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0);

static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node);
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr);

static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc);
static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr);
static OpDescPtr GetOpDescFromOperator(const Operator& oprt);

static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr);

static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name,
const std::string& subgraph_instance_name, OpDescPtr& op_desc);

private:
static GeTensorPtr MutableWeights(ge::OpDesc& op_desc);
static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc);
static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight);
static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight);
};

class OpDescBuilder {
public:
OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {}
OpDescBuilder(const OpDescBuilder&) = delete;
OpDescBuilder& operator=(const OpDescBuilder&) = delete;
OpDescBuilder(const OpDescBuilder&&) = delete;
OpDescBuilder& operator=(const OpDescBuilder&&) = delete;
~OpDescBuilder() = default;

///
/// @brief Add input
/// @param [in] name
/// @return OpDescBuilder
///
OpDescBuilder& AddInput(const std::string& name);

///
/// @brief Add input
/// @param [in] name
/// @param [in] tensor
/// @return OpDescBuilder
///
OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor);

///
/// @brief Add dynamic input
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num);

///
/// @brief Add dynamic input
/// @param [in] name
/// @param [in] num
/// @param [in] tensor
/// @return OpDescBuilder
///
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor);

///
/// @brief Add output
/// @param [in] name
/// @return OpDescBuilder
///
OpDescBuilder& AddOutput(const std::string& name);

///
/// @brief Add output
/// @param [in] name
/// @param [in] tensor
/// @return OpDescBuilder
///
OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor);

///
/// @brief Add dynamic output
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num);

///
/// @brief Add dynamic output
/// @param [in] name
/// @param [in] num
/// @param [in] tensor
/// @return OpDescBuilder
///
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor);

///
/// @brief Build op_desc
/// @return OpDescPtr
///
OpDescPtr Build();

private:
std::string name_;
std::string type_;
std::vector<std::pair<std::string, GeTensorDesc>> inputs_;
std::vector<std::pair<std::string, GeTensorDesc>> outputs_;
};
} // namespace ge

#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_

+ 0
- 43
metadef/inc/graph/utils/tensor_adapter.h View File

@@ -1,43 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_
#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_

#include <memory>
#include "graph/ge_tensor.h"
#include "graph/tensor.h"

namespace ge {
using GeTensorPtr = std::shared_ptr<GeTensor>;
using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter {
public:
static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensorDesc);
static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &geTensorDesc);
static GeTensorPtr Tensor2GeTensor(const Tensor &tensor);
static Tensor GeTensor2Tensor(const ConstGeTensorPtr &geTensor);

static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value
static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value
static const GeTensor AsGeTensor(const Tensor &tensor); // Share value
static GeTensor AsGeTensor(Tensor &tensor); // Share value
static const Tensor AsTensor(const GeTensor &tensor); // Share value
static Tensor AsTensor(GeTensor &tensor); // Share value
};
} // namespace ge
#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_

+ 0
- 77
metadef/inc/graph/utils/tensor_utils.h View File

@@ -1,77 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_
#define INC_GRAPH_UTILS_TENSOR_UTILS_H_

#include <vector>
#include "graph/def_types.h"
#include "graph/ge_error_codes.h"
#include "graph/ge_tensor.h"

namespace ge {
class TensorUtils {
public:
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size);
static void SetSize(GeTensorDesc &tensorDesc, int64_t size);
static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr);
static uint32_t GetWeightSize(const GeTensor &tensor);
static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc);
static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensorPtr, uint8_t *base);
static uint8_t *GetWeightAddr(const GeTensor &tensor, uint8_t *base);
static void SetWeightSize(GeTensorDesc &tensorDesc, uint32_t size);
static ge::graphStatus GetReuseInput(const GeTensorDesc &tensorDesc, bool &flag);
static void SetReuseInput(GeTensorDesc &tensorDesc, bool flag);
static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensorDesc, bool &flag);
static void SetOutputTensor(GeTensorDesc &tensorDesc, bool flag);
static graphStatus GetDeviceType(const GeTensorDesc &tensorDesc, DeviceType &type);
static void SetDeviceType(GeTensorDesc &tensorDesc, DeviceType type);
static ge::graphStatus GetInputTensor(const GeTensorDesc &tensorDesc, bool &flag);
static void SetInputTensor(GeTensorDesc &tensorDesc, bool flag);
static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensorDesc, uint32_t &cnt);
static void SetRealDimCnt(GeTensorDesc &tensorDesc, uint32_t cnt);
static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensorDesc, uint32_t &idx);
static void SetReuseInputIndex(GeTensorDesc &tensorDesc, uint32_t idx);
static ge::graphStatus GetDataOffset(const GeTensorDesc &tensorDesc, int64_t &offset);
static void SetDataOffset(GeTensorDesc &tensorDesc, int64_t offset);
static ge::graphStatus GetCmpsSize(const GeTensorDesc &tensorDesc, uint32_t &cmp_size);
static void SetCmpsSize(GeTensorDesc &tensorDesc, uint32_t cmp_size);
static ge::graphStatus GetCmpsTab(const GeTensorDesc &tensorDesc, vector<uint8_t> &vec);
static void SetCmpsTab(GeTensorDesc &tensorDesc, const uint8_t *data, size_t size);
static ge::graphStatus GetCmpsTabOffset(const GeTensorDesc &tensorDesc, int64_t &tab_offset);
static void SetCmpsTabOffset(GeTensorDesc &tensorDesc, int64_t tab_offset);
static ge::graphStatus GetCmpsInfo(const GeTensorDesc &tensorDesc, CompressInfo &info);
static void SetCmpsInfo(GeTensorDesc &tensorDesc, const CompressInfo &info);
static bool HasAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc);
static ge::graphStatus GetAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc, AllOffsetQuantizeInfo &info);
static void SetAlloffsetQuantizeInfo(GeTensorDesc &tensorDesc, const AllOffsetQuantizeInfo &info);
static ge::graphStatus GetRC(const GeTensorDesc &tensorDesc, uint32_t &rc);
static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc);

///
/// calculate tensor mem size.
/// @param shape tensor shape
/// @param format tensor format
/// @param data_type tensor data type
/// @param mem_size -1 means unknown shape,other means mem size
/// @return GRAPH_SUCCESS:success, other:failed
///
static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size);
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp);
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp);
};
} // namespace ge
#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_

+ 0
- 53
metadef/inc/graph/utils/type_utils.h View File

@@ -1,53 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_
#define INC_GRAPH_UTILS_TYPE_UTILS_H_

#include <map>
#include <unordered_set>
#include <string>
#include "graph/def_types.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/usr_types.h"
#include "register/register_types.h"
#include "external/register/register_fmk_types.h"

namespace ge {
class TypeUtils {
public:
static bool IsDataTypeValid(DataType dt);
static bool IsFormatValid(Format format);
static bool IsInternalFormat(Format format);

static std::string ImplyTypeToSerialString(domi::ImplyType imply_type);
static std::string DataTypeToSerialString(DataType data_type);
static DataType SerialStringToDataType(const std::string &str);
static std::string FormatToSerialString(Format format);
static Format SerialStringToFormat(const std::string &str);
static Format DataFormatToFormat(const std::string &str);
static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format);
static std::string FmkTypeToSerialString(domi::FrameworkType fmk_type);

static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def);
static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr);

static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length);
static bool CheckUint64MulOverflow(uint64_t a, uint32_t b);
};
} // namespace ge
#endif // INC_GRAPH_UTILS_TYPE_UTILS_H_

+ 0
- 127
metadef/proto/dump_task.proto View File

@@ -1,127 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto3";
package toolkit.dumpdata;

enum OutputDataType {
DT_UNDEFINED = 0;
DT_FLOAT = 1;
DT_FLOAT16 = 2;
DT_INT8 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_UINT16 = 6;
DT_INT32 = 7;
DT_INT64 = 8;
DT_UINT32 = 9;
DT_UINT64 = 10;
DT_BOOL = 11;
DT_DOUBLE = 12;
DT_STRING = 13;
DT_DUAL_SUB_INT8 = 14;
DT_DUAL_SUB_UINT8 = 15;
DT_COMPLEX64 = 16;
DT_COMPLEX128 = 17;
DT_QINT8 = 18;
DT_QINT16 = 19;
DT_QINT32 = 20;
DT_QUINT8 = 21;
DT_QUINT16 = 22;
DT_RESOURCE = 23;
DT_STRING_REF = 24;
DT_DUAL = 25;
}

enum OutputFormat {
FORMAT_NCHW = 0;
FORMAT_NHWC = 1;
FORMAT_ND = 2;
FORMAT_NC1HWC0 = 3;
FORMAT_FRACTAL_Z = 4;
FORMAT_NC1C0HWPAD = 5;
FORMAT_NHWC1C0 = 6;
FORMAT_FSR_NCHW = 7;
FORMAT_FRACTAL_DECONV = 8;
FORMAT_C1HWNC0 = 9;
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10;
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11;
FORMAT_NC1HWC0_C04 = 12;
FORMAT_FRACTAL_Z_C04 = 13;
FORMAT_CHWN = 14;
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15;
FORMAT_HWCN = 16;
FORMAT_NC1KHKWHWC0 = 17;
FORMAT_BN_WEIGHT = 18;
FORMAT_FILTER_HWCK = 19;
FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20;
FORMAT_HASHTABLE_LOOKUP_KEYS = 21;
FORMAT_HASHTABLE_LOOKUP_VALUE = 22;
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23;
FORMAT_HASHTABLE_LOOKUP_HITS=24;
FORMAT_C1HWNCoC0 = 25;
FORMAT_MD = 26;
FORMAT_NDHWC = 27;
FORMAT_FRACTAL_ZZ = 28;
FORMAT_FRACTAL_NZ = 29;
FORMAT_RESERVED = 30;
}

message OriginalOp {
string name = 1;
uint32 output_index = 2;
OutputDataType data_type = 3;
OutputFormat format = 4;
}

message Shape {
repeated uint64 dim = 1;
}

message OpOutput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
OriginalOp original_op = 4; // the original op corresponding to the output
bytes data = 5;
uint64 size = 6;
}

message OpInput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
bytes data = 4;
uint64 size = 5;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
bytes data = 2;
uint64 size = 3;
}

message DumpData{
string version = 1;
uint64 dump_time = 2;
repeated OpOutput output = 3;
repeated OpInput input = 4;
repeated OpBuffer buffer = 5;
}

+ 0
- 26
metadef/proto/fusion_model.proto View File

@@ -1,26 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto3";

import "om.proto";

package domi;

message FusionModelDef {
string version = 1;
repeated OpDef fusion_op = 2;
}

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save