From ea410fb3186df0e9c412da6237f17bc632b6f782 Mon Sep 17 00:00:00 2001 From: zhangzhenghai Date: Tue, 4 Aug 2020 15:43:04 +0800 Subject: [PATCH] update common/graph, modify libgraph.so --- src/common/graph/CMakeLists.txt | 1 - src/common/graph/compute_graph.cc | 13 - src/common/graph/debug/ge_op_types.h | 4 - src/common/graph/format_refiner.cc | 90 +--- src/common/graph/format_refiner.h | 8 +- src/common/graph/ge_attr_define.cc | 21 +- src/common/graph/ge_tensor.cc | 11 - src/common/graph/graph.cc | 2 +- src/common/graph/graph.mk | 74 +-- src/common/graph/model_serialize.cc | 25 +- src/common/graph/node.cc | 47 +- src/common/graph/op_desc.cc | 53 +-- src/common/graph/operator.cc | 73 +-- src/common/graph/option/ge_context.cc | 2 - src/common/graph/ref_relation.cc | 4 - src/common/graph/shape_refiner.cc | 172 +------ src/common/graph/stub/Makefile | 6 + src/common/graph/stub/gen_stubapi.py | 573 ++++++++++++++++++++++++ src/common/graph/utils/ge_ir_utils.h | 10 +- src/common/graph/utils/graph_utils.cc | 79 +--- src/common/graph/utils/node_utils.cc | 171 +------ src/common/graph/utils/op_desc_utils.cc | 66 +-- src/common/graph/utils/tensor_utils.cc | 8 +- src/common/graph/utils/type_utils.cc | 3 +- 24 files changed, 779 insertions(+), 737 deletions(-) create mode 100644 src/common/graph/stub/Makefile create mode 100644 src/common/graph/stub/gen_stubapi.py diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt index f041e4b6..43f5b597 100755 --- a/src/common/graph/CMakeLists.txt +++ b/src/common/graph/CMakeLists.txt @@ -71,6 +71,5 @@ target_link_libraries(graph PRIVATE ${PROTOBUF_LIBRARY} ${c_sec} ${slog} - ${error_manager} rt dl) diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc index 8a0c9f06..b73cf939 100644 --- a/src/common/graph/compute_graph.cc +++ b/src/common/graph/compute_graph.cc @@ -106,15 +106,6 @@ ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector(shared_from_this(), all_nodes); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetNodes( - bool is_unknown_shape) const { - if (is_unknown_shape) { - return GetDirectNode(); - } else { - return GetAllNodes(); - } -} - size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { @@ -506,10 +497,6 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptrGetName()) { GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); } - if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { - GE_LOGE("The subgraph %s existed", name.c_str()); - return GRAPH_PARAM_INVALID; - } sub_graph_.push_back(subgraph); names_to_subgraph_[name] = subgraph; return GRAPH_SUCCESS; diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h index f11ef31e..da36f72c 100644 --- a/src/common/graph/debug/ge_op_types.h +++ b/src/common/graph/debug/ge_op_types.h @@ -34,16 +34,12 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); GE_REGISTER_OPTYPE(SWITCH, "Switch"); GE_REGISTER_OPTYPE(MERGE, "Merge"); GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); -GE_REGISTER_OPTYPE(ENTER, "Enter"); -GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); GE_REGISTER_OPTYPE(CONSTANT, "Const"); -GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); GE_REGISTER_OPTYPE(INITDATA, "InitData"); -GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc index 9cb76539..11a610ce 100644 --- a/src/common/graph/format_refiner.cc +++ b/src/common/graph/format_refiner.cc @@ -41,9 +41,11 @@ using namespace ge; using namespace std; namespace ge { namespace { -const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; -const string kIsGraphInferred = "_is_graph_inferred"; -RefRelations reflection_builder; +static const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; +static bool net_format_is_nd = true; +static Format g_user_set_format = FORMAT_ND; +static bool is_first_infer = true; +static RefRelations reflection_builder; } // namespace graphStatus ReflectionProcess(const std::unordered_set &reflection, @@ -70,49 +72,9 @@ graphStatus ReflectionProcess(const std::unordered_set &re return GRAPH_SUCCESS; } -graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { - // 5 meas dim num - if (node_ptr->GetType() != "BiasAdd") { - return GRAPH_SUCCESS; - } - std::unordered_map kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; - for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { - auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); - GE_CHECK_NOTNULL(in_desc); - if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num - continue; - } - auto format = in_desc->GetOriginFormat(); - auto key = TypeUtils::FormatToSerialString(format); - auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; - in_desc->SetOriginFormat(fixed_format); - in_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, - node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { - auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); - GE_CHECK_NOTNULL(out_desc); - if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num - continue; - } - auto format = out_desc->GetOriginFormat(); - auto key = TypeUtils::FormatToSerialString(format); - auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; - out_desc->SetOriginFormat(fixed_format); - out_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, - node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(graph); +graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { + if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { ConstGeTensorPtr tensor_value; if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); @@ -133,7 +95,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std } anchor_points.clear(); // Get all anchor point nodes and switch nodes - for (auto &node_ptr : graph->GetAllNodes()) { + for (const auto &node_ptr : graph->GetAllNodes()) { if (node_ptr == nullptr) { return GRAPH_FAILED; } @@ -141,7 +103,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std if (op_desc == nullptr) { return GRAPH_FAILED; } - graphStatus status = RefreshConstantOutProcess(graph, op_desc); + graphStatus status = RefreshConstantOutProcess(op_desc); if (status != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); return GRAPH_FAILED; @@ -173,16 +135,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std if (!node_is_all_nd) { continue; } - // special process for biasAdd op - // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg - // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism - // so here do special process - status = BiasAddFormatFixProcess(node_ptr); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); - return GRAPH_FAILED; - } - GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); anchor_points.push_back(node_ptr); } @@ -392,11 +344,14 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor } } -graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, - ge::Format data_format, +void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; } + +graphStatus FormatRefiner::DataNodeFormatProcess(std::vector &data_nodes, ge::Format data_format, std::unordered_map &node_status) { - if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { - GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), + bool is_internal_format = TypeUtils::IsInternalFormat(data_format); + bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); + if (!need_process) { + GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, TypeUtils::FormatToSerialString(data_format).c_str()); return GRAPH_SUCCESS; } @@ -455,6 +410,8 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) std::vector anchor_points; std::vector data_nodes; // global net format + net_format_is_nd = true; + g_user_set_format = FORMAT_ND; if (graph == nullptr) { GELOGE(GRAPH_FAILED, "input graph is null"); @@ -491,15 +448,10 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) /// format for these data nodes. /// Notice: ignore 5D formats auto data_format = graph->GetDataFormat(); - status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); - - (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); + status = DataNodeFormatProcess(data_nodes, data_format, node_status); + // Set infer flag to false + SetInferOrigineFormatFlag(false); return status; } - -bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { - bool is_graph_inferred = false; - return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); -} } // namespace ge diff --git a/src/common/graph/format_refiner.h b/src/common/graph/format_refiner.h index eca93bae..fa40a034 100644 --- a/src/common/graph/format_refiner.h +++ b/src/common/graph/format_refiner.h @@ -30,9 +30,10 @@ namespace ge { class FormatRefiner { public: static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); + static void SetInferOrigineFormatFlag(bool is_first = true); private: - static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); + static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc); static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, std::vector &data_nodes, std::unordered_map &node_status); @@ -42,9 +43,8 @@ class FormatRefiner { std::unordered_map &node_status); static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, std::unordered_map &node_status); - static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, - ge::Format data_format, std::unordered_map &node_status); - static bool IsGraphInferred(const ComputeGraphPtr &graph); + static graphStatus DataNodeFormatProcess(std::vector &data_nodes, ge::Format data_format, + std::unordered_map &node_status); }; } // namespace ge #endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc index 90f1bc6a..96638249 100644 --- a/src/common/graph/ge_attr_define.cc +++ b/src/common/graph/ge_attr_define.cc @@ -725,10 +725,6 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; const std::string ATTR_MODEL_CORE_TYPE = "core_type"; -const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; - -const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; - // Public attribute const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; @@ -938,7 +934,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; const std::string MODEL_ATTR_SESSION_ID = "session_id"; -// lx fusion +// l1 fusion and other fusion in future const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; @@ -952,17 +948,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; -const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; -const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; -const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; -const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; - -// Op debug attrs -const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; -const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; // Atomic addr clean attrs const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; @@ -1027,11 +1015,4 @@ const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; // used for allreduce tailing optimization const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; - -// dynamic shape attr -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; - -// for fusion op plugin -const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; } // namespace ge diff --git a/src/common/graph/ge_tensor.cc b/src/common/graph/ge_tensor.cc index 196b8569..8ffbba91 100644 --- a/src/common/graph/ge_tensor.cc +++ b/src/common/graph/ge_tensor.cc @@ -220,7 +220,6 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; -const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} @@ -568,16 +567,6 @@ DataType GeTensorDesc::GetOriginDataType() const { return TypeUtils::SerialStringToDataType(origin_data_type_str); } -std::vector GeTensorDesc::GetRefPortIndex() const { - vector ref_port_index; - (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); - return ref_port_index; -} - -void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { - (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); -} - graphStatus GeTensorDesc::IsValid() const { auto dtype = this->GetDataType(); auto format = this->GetFormat(); diff --git a/src/common/graph/graph.cc b/src/common/graph/graph.cc index fc30e9d6..09d4fd56 100644 --- a/src/common/graph/graph.cc +++ b/src/common/graph/graph.cc @@ -210,7 +210,7 @@ class GraphImpl { graphStatus FindOpByName(const string &name, ge::Operator &op) const { auto it = op_list_.find(name); - GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); + GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str()); op = it->second; return GRAPH_SUCCESS; } diff --git a/src/common/graph/graph.mk b/src/common/graph/graph.mk index 14e8b4b1..5eaf7d86 100644 --- a/src/common/graph/graph.mk +++ b/src/common/graph/graph.mk @@ -77,7 +77,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ - liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -95,35 +94,10 @@ LOCAL_CPPFLAGS += -fexceptions LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := fwk_stub/libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/attr_value.cc \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - ../../out/graph/lib64/stub/tensor.cc \ + ../../out/atc/lib64/stub/graph.cc \ + ../../out/atc/lib64/stub/operator.cc \ + ../../out/atc/lib64/stub/tensor.cc \ + ../../out/atc/lib64/stub/operator_factory.cc \ LOCAL_SHARED_LIBRARIES := @@ -148,7 +122,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ - liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -169,38 +142,10 @@ LOCAL_CFLAGS += -O2 LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := fwk_stub/libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/attr_value.cc \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - ../../out/graph/lib64/stub/tensor.cc \ + ../../out/atc/lib64/stub/graph.cc \ + ../../out/atc/lib64/stub/operator.cc \ + ../../out/atc/lib64/stub/tensor.cc \ + ../../out/atc/lib64/stub/operator_factory.cc \ LOCAL_SHARED_LIBRARIES := @@ -229,7 +174,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ - liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -255,7 +199,6 @@ LOCAL_STATIC_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ - liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -279,7 +222,6 @@ LOCAL_STATIC_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ - liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/src/common/graph/model_serialize.cc b/src/common/graph/model_serialize.cc index 4bd5769f..19cb4538 100644 --- a/src/common/graph/model_serialize.cc +++ b/src/common/graph/model_serialize.cc @@ -88,8 +88,10 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ } bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); + if (op_desc == nullptr || op_def_proto == nullptr) { + GELOGE(GRAPH_FAILED, "Input Para Invalid"); + return false; + } if (op_desc->op_def_.GetProtoMsg() != nullptr) { *op_def_proto = *op_desc->op_def_.GetProtoMsg(); // Delete unnecessary attr @@ -128,17 +130,16 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { op_def_proto->add_subgraph_name(name); } - if (!op_desc->output_name_idx_.empty()) { - proto::AttrDef key; - proto::AttrDef value; - for (auto &item : op_desc->output_name_idx_) { - key.mutable_list()->add_s(item.first); - value.mutable_list()->add_i(item.second); - } - auto op_desc_attr = op_def_proto->mutable_attr(); - op_desc_attr->insert({"_output_name_key", key}); - op_desc_attr->insert({"_output_name_value", value}); + + proto::AttrDef key; + proto::AttrDef value; + for (auto &item : op_desc->output_name_idx_) { + key.mutable_list()->add_s(item.first); + value.mutable_list()->add_i(item.second); } + auto op_desc_attr = op_def_proto->mutable_attr(); + op_desc_attr->insert({"_output_name_key", key}); + op_desc_attr->insert({"_output_name_value", value}); } return true; } diff --git a/src/common/graph/node.cc b/src/common/graph/node.cc index df8efd91..e0939e7e 100644 --- a/src/common/graph/node.cc +++ b/src/common/graph/node.cc @@ -26,7 +26,6 @@ #include "utils/ge_ir_utils.h" #include "utils/node_utils.h" #include "utils/op_desc_utils.h" -#include "common/util/error_manager/error_manager.h" using std::string; using std::vector; @@ -155,7 +154,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); if (peer_node == nullptr || r_peer_node == nullptr) { - GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", + GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", this->GetName().c_str(), i, j); return false; } @@ -435,11 +434,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::Get GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); + GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(), + GetName().c_str()); return nullptr; } else { return in_data_anchors_[idx]; @@ -449,10 +445,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "in_anchor", GetType().c_str()}); - GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); + GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str()); return nullptr; } else { // Return control anchor @@ -468,15 +461,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, - { - GetName().c_str(), - std::to_string(idx), - "out_anchor", - GetType().c_str(), - }); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); + GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(), + GetName().c_str()); return nullptr; } else { // Return control anchor @@ -491,11 +477,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); + GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(), + GetName().c_str()); return nullptr; } else { return out_data_anchors_[idx]; @@ -750,15 +733,11 @@ graphStatus Node::Verify() const { GELOGW("in anchor ptr is null"); continue; } - bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || - op_->GetType() == const_type || op_->GetType() == variable_type || - op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; - if (!valid_anchor) { - ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"name", "index"}, - {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); - GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); - return GRAPH_FAILED; - } + GE_CHK_BOOL_RET_STATUS( + op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || + op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || + in_anchor_ptr->GetPeerAnchors().size() > 0, + GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); } string frameworkop_type = "FrameworkOp"; diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc index e9436a32..adb52162 100644 --- a/src/common/graph/op_desc.cc +++ b/src/common/graph/op_desc.cc @@ -19,7 +19,6 @@ #include "debug/ge_util.h" #include "external/graph/operator.h" #include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" #include "graph/ge_attr_value.h" #include "graph/ge_tensor.h" #include "graph/operator_factory_impl.h" @@ -471,25 +470,6 @@ GeTensorDesc OpDesc::GetInputDesc(const string &name) const { return *(inputs_desc_[it->second].get()); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); - return inputs_desc_[index]; -} - -GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - if (it == input_name_idx.end()) { - GELOGW("Failed to get [%s] input desc", name.c_str()); - return nullptr; - } - return MutableInputDesc(it->second); -} - GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { auto input_name_idx = GetAllInputName(); vector names; @@ -516,6 +496,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); + return inputs_desc_[index]; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { vector temp{}; for (const auto &it : inputs_desc_) { @@ -620,15 +609,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu return outputs_desc_[index]; } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { - auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("Failed to get [%s] output desc", name.c_str()); - return nullptr; - } - return MutableOutputDesc(it->second); -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { return static_cast(outputs_desc_.size()); } @@ -902,22 +882,15 @@ graphStatus OpDesc::CommonVerify() const { // Checking shape of all inputs vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); for (int64_t dim : ishape) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); - return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), - iname.c_str()); + GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", + iname.c_str()); } } // Check all attributes defined const auto &all_attributes = GetAllAttrs(); for (const auto &name : GetAllAttrNames()) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - all_attributes.find(name) == all_attributes.end(), - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {GetName(), "attribute " + name, "is empty"}); - return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); + GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, + "operator attribute %s is empty.", name.c_str()); } return GRAPH_SUCCESS; diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc index 3a9fd698..1ac8d41d 100644 --- a/src/common/graph/operator.cc +++ b/src/common/graph/operator.cc @@ -36,8 +36,6 @@ #include "graph/op_desc.h" #include "graph/runtime_inference_context.h" #include "graph/usr_types.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_attr_define.h" #include "utils/graph_utils.h" #include "utils/op_desc_utils.h" #include "utils/tensor_adapter.h" @@ -59,7 +57,8 @@ using std::vector; namespace ge { class OpIO { public: - OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} + explicit OpIO(const string &name, int index, const OperatorImplPtr &owner) + : name_(name), index_(index), owner_(owner) {} ~OpIO() = default; @@ -547,46 +546,56 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) { } graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { - GE_CHECK_NOTNULL(operator_impl_); - auto node_ptr = operator_impl_->GetNode(); - if (node_ptr != nullptr) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr."); + return GRAPH_FAILED; + } + ge::ConstNodePtr node_ptr = operator_impl_->GetNode(); + if (node_ptr) { // For inner compute graph auto op_desc = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is nullptr."); + return GRAPH_FAILED; + } auto index = op_desc->GetInputIndexByName(dst_name); auto in_data_anchor = node_ptr->GetInDataAnchor(index); - GE_CHECK_NOTNULL(in_data_anchor); + if (in_data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr."); + return GRAPH_FAILED; + } auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(out_data_anchor); - auto peer_node = out_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); - auto peer_op_desc = peer_node->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - auto peer_op_type = peer_op_desc->GetType(); - if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { - auto const_op_impl = ComGraphMakeShared(peer_node); - GE_CHECK_NOTNULL(const_op_impl); - Operator const_op(std::move(const_op_impl)); - return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); - } else if (peer_op_type == DATA) { - auto parent_node = NodeUtils::GetParentInput(peer_node); - while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { - parent_node = NodeUtils::GetParentInput(parent_node); - } - if ((parent_node != nullptr) && - ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { - auto const_op_impl = ComGraphMakeShared(parent_node); - GE_CHECK_NOTNULL(const_op_impl); - Operator const_op(std::move(const_op_impl)); - return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + if (out_data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr."); + return GRAPH_FAILED; + } + std::shared_ptr peer_node_ptr = out_data_anchor->GetOwnerNode(); + if (peer_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr."); + return GRAPH_FAILED; + } + ge::OperatorImplPtr operator_impl_ptr = nullptr; + operator_impl_ptr = ComGraphMakeShared(peer_node_ptr); + if (operator_impl_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return GRAPH_FAILED; + } + Operator const_op(std::move(operator_impl_ptr)); + if (peer_node_ptr->GetOpDesc() != nullptr) { + const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); + if (op_descType == CONSTANTOP) { + return const_op.GetAttr(op::Constant::name_attr_value(), data); + } else if (op_descType == CONSTANT) { + return const_op.GetAttr(op::Const::name_attr_value(), data); } } + // Try get from runtime inference context auto session_id = std::to_string(GetContext().SessionId()); RuntimeInferenceContext *runtime_infer_ctx = nullptr; if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); - auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); + auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); if (ret == GRAPH_SUCCESS) { return GRAPH_SUCCESS; } @@ -595,8 +604,6 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co // For outer graph return GetInputConstDataOut(dst_name, data); } - auto op_name = operator_impl_->GetName(); - GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); return GRAPH_FAILED; } graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { diff --git a/src/common/graph/option/ge_context.cc b/src/common/graph/option/ge_context.cc index f5f5e4c9..f5ebdeee 100644 --- a/src/common/graph/option/ge_context.cc +++ b/src/common/graph/option/ge_context.cc @@ -85,8 +85,6 @@ uint32_t GEContext::DeviceId() { return device_id_; } uint64_t GEContext::TraceId() { return trace_id_; } -void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } - void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } } // namespace ge diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc index 906cb5f9..b3cf37af 100644 --- a/src/common/graph/ref_relation.cc +++ b/src/common/graph/ref_relation.cc @@ -242,10 +242,6 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r int sub_graph_idx = 0; for (const auto &name : sub_graph_names) { auto sub_graph = root_graph.GetSubgraph(name); - if (sub_graph == nullptr) { - GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); - continue; - } for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { auto sub_graph_node_type = sub_graph_node->GetType(); diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc index dc1bc541..edf426a5 100644 --- a/src/common/graph/shape_refiner.cc +++ b/src/common/graph/shape_refiner.cc @@ -37,115 +37,6 @@ namespace ge { namespace { -const uint32_t kWhileBodySubGraphIdx = 1; - -graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { - GELOGD("Enter reverse brush while body subgraph process!"); - - auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); - if (sub_graph_body == nullptr) { - GELOGE(GRAPH_FAILED, "Get while body graph failed!"); - return GRAPH_FAILED; - } - - for (const auto &node_sub : sub_graph_body->GetAllNodes()) { - if (node_sub->GetInDataNodes().size() == 0) { - continue; - } - - for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { - auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); - (void)input_desc->SetUnknownDimNumShape(); - } - for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { - auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); - (void)output_desc->SetUnknownDimNumShape(); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class branch op process"); - // check sub_graph shape.If not same ,do unknown shape process - for (size_t i = 0; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto ref_out_tensor = ref_out_tensors[i].at(0); - ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); - for (auto &tensor : ref_out_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); - return GRAPH_FAILED; - } - auto shape = tensor.MutableShape(); - if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { - GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, - shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - ref_out_tensor_shape = GeShape(UNKNOWN_RANK); - break; - } - for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { - if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { - continue; - } - GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, - j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector> &ref_data_tensors, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class while op process"); - if (ref_data_tensors.size() != ref_out_tensors.size()) { - GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), - ref_data_tensors.size(), ref_out_tensors.size()); - return GRAPH_FAILED; - } - for (size_t i = 0; i < ref_data_tensors.size(); i++) { - if (ref_out_tensors[i].size() != 1) { - GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); - return GRAPH_FAILED; - } - } - bool is_need_reverse_brush = false; - // check input and output - for (size_t i = 0; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto ref_out_tensor = ref_out_tensors[i].at(0); - auto tmp_shape = ref_out_tensor.MutableShape(); - // ref_i's data and output tensor shape should be same - for (auto &tensor : ref_data_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); - return GRAPH_FAILED; - } - auto shape = tensor.MutableShape(); - if (shape.GetDims() != tmp_shape.GetDims()) { - ref_out_tensor.SetUnknownDimNumShape(); - is_need_reverse_brush = true; - break; - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); - } - // reverse refresh while body shape - if (is_need_reverse_brush) { - return ReverseBrushWhileBodySubGraph(node); - } - return GRAPH_SUCCESS; -} - graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -207,37 +98,6 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { } return GRAPH_SUCCESS; } - -graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, NodePtr &netoutput, - const ConstNodePtr &node, - std::vector> &ref_data_tensors) { - auto sub_nodes = sub_graph->GetDirectNode(); - for (size_t i = sub_nodes.size(); i > 0; --i) { - auto sub_node = sub_nodes.at(i - 1); - if (sub_node->GetType() == NETOUTPUT) { - netoutput = sub_node; - } - if (sub_node->GetType() == DATA) { - if (sub_node->GetOpDesc() == nullptr) { - return GRAPH_FAILED; - } - - int ref_i; - if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); - return GRAPH_FAILED; - } - if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { - GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), - ref_i, node->GetAllInDataAnchorsSize()); - return GRAPH_FAILED; - } - ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); - } - } - return GRAPH_SUCCESS; -} - graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -245,10 +105,7 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { return GRAPH_SUCCESS; } - std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); - std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - for (const auto &name : sub_graph_names) { if (name.empty()) { GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); @@ -260,9 +117,13 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { return GRAPH_FAILED; } NodePtr netoutput = nullptr; - auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); - if (ret != GRAPH_SUCCESS) { - return ret; + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + netoutput = sub_node; + break; + } } if (netoutput == nullptr) { GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); @@ -289,17 +150,19 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { continue; } GELOGI("Parent node index of edge desc is %d", ref_i); - if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { + auto output_desc = op_desc->MutableOutputDesc(static_cast(ref_i)); + if (output_desc == nullptr) { + GE_LOGE( + "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " + "parent node %s are incompatible, outputs num %u", + ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), + node->GetAllOutDataAnchorsSize()); return GRAPH_FAILED; } - ref_out_tensors[ref_i].emplace_back(*edge_desc); + op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); } } - - if (node->GetType() == WHILE) { - return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); - } - return UpdateParentNodeForBranch(node, ref_out_tensors); + return GRAPH_SUCCESS; } } // namespace void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { @@ -307,9 +170,6 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str GELOGE(GRAPH_FAILED, "node is null"); return; } - if (!IsLogEnable(GE, DLOG_DEBUG)) { - return; - } ge::OpDescPtr op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); std::string str; diff --git a/src/common/graph/stub/Makefile b/src/common/graph/stub/Makefile new file mode 100644 index 00000000..832adcd5 --- /dev/null +++ b/src/common/graph/stub/Makefile @@ -0,0 +1,6 @@ +inc_path := $(shell pwd)/inc/external/ +out_path := $(shell pwd)/out/atc/lib64/stub/ +stub_path := $(shell pwd)/common/graph/stub/ + +mkdir_stub := $(shell mkdir -p $(out_path)) +graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/src/common/graph/stub/gen_stubapi.py b/src/common/graph/stub/gen_stubapi.py new file mode 100644 index 00000000..6185c479 --- /dev/null +++ b/src/common/graph/stub/gen_stubapi.py @@ -0,0 +1,573 @@ +import os +import re +import sys +import logging + +logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', + level=logging.INFO) + +""" + this attr is used for symbol table visible +""" +GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' + +""" + generate stub func body by return type +""" +RETURN_STATEMENTS = { + 'graphStatus': ' return GRAPH_SUCCESS;', + 'Status': ' return SUCCESS;', + 'Graph': ' return Graph();', + 'Graph&': ' return *this;', + 'Format': ' return Format();', + 'Format&': ' return *this;', + 'Shape': ' return Shape();', + 'Shape&': ' return *this;', + 'TensorDesc': ' return TensorDesc();', + 'TensorDesc&': ' return *this;', + 'Tensor': ' return Tensor();', + 'Tensor&': ' return *this;', + 'Operator': ' return Operator();', + 'Operator&': ' return *this;', + 'Ptr': ' return nullptr;', + 'std::string': ' return "";', + 'std::string&': ' return "";', + 'string': ' return "";', + 'int': ' return 0;', + 'DataType': ' return DT_FLOAT;', + 'InferenceContextPtr': ' return nullptr;', + 'SubgraphBuilder': ' return nullptr;', + 'OperatorImplPtr': ' return nullptr;', + 'OutHandler': ' return nullptr;', + 'std::vector': ' return {};', + 'std::vector': ' return {};', + 'std::map': ' return {};', + 'uint32_t': ' return 0;', + 'int64_t': ' return 0;', + 'uint64_t': ' return 0;', + 'size_t': ' return 0;', + 'float': ' return 0.0f;', + 'bool': ' return false;', +} + +""" + max code len per line in hua_wei software programming specifications +""" +max_code_len_per_line = 100 + +""" + white_list_for_debug, include_dir_key_words is to + determines which header files to generate cc files from + when DEBUG on +""" +white_list_for_debug = ["operator.h", "tensor.h", + "graph.h", "operator_factory.h", + "ge_ir_build.h"] +include_dir_key_words = ["ge", "graph"] +DEBUG = True + + +def need_generate_func(func_line): + """ + :param func_line: + :return: + """ + if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ + or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): + return False + return True + + +def file_endswith_white_list_suffix(file): + """ + :param file: + :return: + """ + if DEBUG: + for suffix in white_list_for_debug: + if file.endswith(suffix): + return True + return False + else: + return True + + +""" + belows are patterns used for analyse .h file +""" +# pattern function +pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after +([a-zA-Z~_] # void int likely +.* +[)] #we find ) +(?!.*{) # we do not want the case int abc() const { return 1;} +.*) +(;.*) #we want to find ; and after for we will replace these later +\n$ +""", re.VERBOSE | re.MULTILINE | re.DOTALL) + +# pattern comment +pattern_comment = re.compile(r'^\s*//') +pattern_comment_2_start = re.compile(r'^\s*/[*]') +pattern_comment_2_end = re.compile(r'[*]/\s*$') +# pattern define +pattern_define = re.compile(r'^\s*#define') +pattern_define_return = re.compile(r'\\\s*$') +# blank line +pattern_blank_line = re.compile(r'^\s*$') +# virtual,explicit,friend,static +pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') +# lead space +pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') +# functions will have patterns such as func ( or func( +# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist +# format like :"operator = ()" +pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') +# template +pattern_template = re.compile(r'^\s*template') +pattern_template_end = re.compile(r'>\s*$') +# namespace +pattern_namespace = re.compile(r'namespace.*{') +# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with +pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: + line, func_name = self.handle_class_member_func(line, template_string) + # Normal functions + else: + line, func_name = self.handle_normal_func(line, template_string) + + need_generate = need_generate_func(line) + # func body + line += self.implement_function(line) + # comment + line = self.gen_comment(start_i) + line + # write to out file + self.write_func_content(line, func_name, need_generate) + # next loop + self.line_index += 1 + + logging.info('Added %s functions', len(self.func_list_exist)) + logging.info('Successfully converted,please see ' + self.output_file) + + def handle_func1(self, line): + """ + :param line: + :return: + """ + find1 = re.search('[(]', line) + if not find1: + self.line_index += 1 + return "continue", line, None + find2 = re.search('[)]', line) + start_i = self.line_index + space_match = pattern_leading_space.search(line) + # deal with + # int abc(int a, + # int b) + if find1 and (not find2): + self.line_index += 1 + line2 = self.input_content[self.line_index] + if space_match: + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): + self.line_index += 1 + line2 = self.input_content[self.line_index] + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + + match_start = pattern_start.search(self.input_content[self.line_index]) + match_end = pattern_end.search(self.input_content[self.line_index]) + if match_start: # like ) { or ) {} int the last line + if not match_end: + self.stack.append('normal_now') + ii = start_i + while ii <= self.line_index: + ii += 1 + self.line_index += 1 + return "continue", line, start_i + logging.info("line[%s]", line) + # ' int abc();'->'int abc()' + (line, match) = pattern_func.subn(r'\2\n', line) + logging.info("line[%s]", line) + # deal with case: + # 'int \n abc(int a, int b)' + if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): + line = self.input_content[start_i - 1] + line + line = line.lstrip() + if not match: + self.line_index += 1 + return "continue", line, start_i + return "pass", line, start_i + + def handle_stack(self, match_start): + """ + :param match_start: + :return: + """ + line = self.input_content[self.line_index] + match_end = pattern_end.search(line) + if match_start: + self.stack.append('normal_now') + if match_end: + top_status = self.stack.pop() + if top_status == 'namespace_now': + self.output_fd.write(line + '\n') + elif top_status == 'class_now': + self.stack_class.pop() + self.stack_template.pop() + if match_start or match_end: + self.line_index += 1 + return "continue" + + if len(self.stack) > 0 and self.stack[-1] == 'normal_now': + self.line_index += 1 + return "continue" + return "pass" + + def handle_class(self, template_string, line, match_start, match_class): + """ + :param template_string: + :param line: + :param match_start: + :param match_class: + :return: + """ + if match_class: # we face a class + self.stack_template.append(template_string) + self.stack.append('class_now') + class_name = match_class.group(3) + + # class template specializations: class A > + if '<' in class_name: + k = line.index('<') + fit = 1 + for ii in range(k + 1, len(line)): + if line[ii] == '<': + fit += 1 + if line[ii] == '>': + fit -= 1 + if fit == 0: + break + class_name += line[k + 1:ii + 1] + logging.info('class_name[%s]', class_name) + self.stack_class.append(class_name) + while not match_start: + self.line_index += 1 + line = self.input_content[self.line_index] + match_start = pattern_start.search(line) + self.line_index += 1 + return "continue" + return "pass" + + def handle_template(self): + line = self.input_content[self.line_index] + match_template = pattern_template.search(line) + template_string = '' + if match_template: + match_template_end = pattern_template_end.search(line) + template_string = line + while not match_template_end: + self.line_index += 1 + line = self.input_content[self.line_index] + template_string += line + match_template_end = pattern_template_end.search(line) + self.line_index += 1 + return template_string + + def handle_namespace(self): + line = self.input_content[self.line_index] + match_namespace = pattern_namespace.search(line) + if match_namespace: # we face namespace + self.output_fd.write(line + '\n') + self.stack.append('namespace_now') + self.line_index += 1 + + def handle_normal_func(self, line, template_string): + template_line = '' + self.stack_template.append(template_string) + if self.stack_template[-1] != '': + template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) + # change '< class T = a, class U = A(3)>' to '' + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = re.sub(r'\s*=.*,', ',', line) + line = re.sub(r'\s*=.*\)', ')', line) + line = template_line + line + self.stack_template.pop() + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def handle_class_member_func(self, line, template_string): + template_line = '' + x = '' + if template_string != '': + template_string = re.sub(r'\s*template', 'template', template_string) + template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) + template_string = re.sub(r'\s*=.*,', ',', template_string) + template_string = re.sub(r'\s*=.*', '', template_string) + if self.stack_template[-1] != '': + if not (re.search(r'<\s*>', stack_template[-1])): + template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) + if not (re.search(r'<.*>', self.stack_class[-1])): + # for x we get like template -> + x = re.sub(r'template\s*<', '<', template_line) # remove template -> + x = re.sub(r'\n', '', x) + x = re.sub(r'\s*=.*,', ',', x) + x = re.sub(r'\s*=.*\>', '>', x) + x = x.rstrip() # remove \n + x = re.sub(r'(class|typename)\s+|(|\s*class)', '', + x) # remove class,typename -> + x = re.sub(r'<\s+', '<', x) + x = re.sub(r'\s+>', '>', x) + x = re.sub(r'\s+,', ',', x) + x = re.sub(r',\s+', ', ', x) + line = re.sub(r'\s*=\s+0', '', line) + line = re.sub(r'\s*=\s+.*,', ',', line) + line = re.sub(r'\s*=\s+.*\)', ')', line) + logging.info("x[%s]\nline[%s]", x, line) + # if the function is long, void ABC::foo() + # breaks into two lines void ABC::\n foo() + temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) + if len(temp_line) > max_code_len_per_line: + line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) + else: + line = temp_line + logging.info("line[%s]", line) + # add template as the above if there is one + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = template_line + template_string + line + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def write_func_content(self, content, func_name, need_generate): + if not (func_name in self.func_list_exist) and need_generate: + self.output_fd.write(content) + self.func_list_exist.append(func_name) + logging.info('add func:[%s]', func_name) + + def gen_comment(self, start_i): + comment_line = '' + # Function comments are on top of function declarations, copy them over + k = start_i - 1 # one line before this func start + if pattern_template.search(self.input_content[k]): + k -= 1 + if pattern_comment_2_end.search(self.input_content[k]): + comment_line = self.input_content[k].lstrip() + while not pattern_comment_2_start.search(self.input_content[k]): + k -= 1 + comment_line = self.input_content[k].lstrip() + comment_line + else: + for j in range(k, 0, -1): + c_line = self.input_content[j] + if pattern_comment.search(c_line): + c_line = re.sub(r'\s*//', '//', c_line) + comment_line = c_line + comment_line + else: + break + return comment_line + + @staticmethod + def implement_function(func): + function_def = '' + function_def += '{\n' + + all_items = func.split() + start = 0 + return_type = all_items[start] + if return_type == "const": + start += 1 + return_type = all_items[start] + if return_type.startswith(('std::map', 'std::set', 'std::vector')): + return_type = "std::map" + if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): + return_type = "Ptr" + if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): + return_type += "&" + if RETURN_STATEMENTS.__contains__(return_type): + function_def += RETURN_STATEMENTS[return_type] + else: + logging.warning("Unhandled return type[%s]", return_type) + + function_def += '\n' + function_def += '}\n' + function_def += '\n' + return function_def + + +def collect_header_files(path): + """ + :param path: + :return: + """ + header_files = [] + shared_includes_content = [] + for root, dirs, files in os.walk(path): + files.sort() + for file in files: + if file.find("git") >= 0: + continue + if not file.endswith('.h'): + continue + file_path = os.path.join(root, file) + file_path = file_path.replace('\\', '/') + header_files.append(file_path) + include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) + shared_includes_content.append(include_str) + return header_files, shared_includes_content + + +def generate_stub_file(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + target_header_files, shared_includes_content = collect_header_files(inc_dir) + for header_file in target_header_files: + if not file_endswith_white_list_suffix(header_file): + continue + cc_file = re.sub('.h*$', '.cc', header_file) + h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) + h_2_cc.h2cc() + + +def gen_code(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + if not inc_dir.endswith('/'): + inc_dir += '/' + if not out_cc_dir.endswith('/'): + out_cc_dir += '/' + for include_dir_key_word in include_dir_key_words: + generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) + + +if __name__ == '__main__': + inc_dir = sys.argv[1] + out_cc_dir = sys.argv[2] + gen_code(inc_dir, out_cc_dir) diff --git a/src/common/graph/utils/ge_ir_utils.h b/src/common/graph/utils/ge_ir_utils.h index b572ab38..9b16be18 100644 --- a/src/common/graph/utils/ge_ir_utils.h +++ b/src/common/graph/utils/ge_ir_utils.h @@ -1,18 +1,18 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd - + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - + * * http://www.apache.org/licenses/LICENSE-2.0 - + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. -*/ + */ #ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ #define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc index a6980358..ca2ebcdc 100644 --- a/src/common/graph/utils/graph_utils.cc +++ b/src/common/graph/utils/graph_utils.cc @@ -38,7 +38,6 @@ #include "utils/ge_ir_utils.h" #include "utils/node_utils.h" #include "debug/ge_op_types.h" -#include "external/ge/ge_api_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" @@ -411,8 +410,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra /// @return graphStatus /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { +GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { GE_CHECK_NOTNULL(src); GE_CHECK_NOTNULL(insert_node); @@ -571,7 +570,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons static int max_dumpfile_num = 0; if (max_dumpfile_num == 0) { string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + (void)GetContext().GetOption("ge.maxDumpFileNum", opt); max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); } if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { @@ -671,7 +670,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText if (maxDumpFileSize == 0) { string opt = "0"; // Can not check return value - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); + (void)GetContext().GetOption("ge.maxDumpFileSize", opt); maxDumpFileSize = atol(opt.c_str()); } if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { @@ -741,7 +740,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn static int max_dumpfile_num = 0; if (max_dumpfile_num == 0) { string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + (void)GetContext().GetOption("ge.maxDumpFileNum", opt); max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); } if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { @@ -921,7 +920,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, In InNodesToOut GetFullConnectIONodes(const NodePtr &node) { InNodesToOut in_nodes_to_out; if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); + GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str()); return in_nodes_to_out; } auto in_nodes_list = node->GetInNodes(); @@ -1309,36 +1308,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt return GRAPH_SUCCESS; } -/// -/// Copy all in-data edges from `src_node` to `dst_node`. -/// @param src_node -/// @param dst_node -/// @return -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto src_data_in_nodes = src_node->GetInDataNodes(); - if (src_data_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { - auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); - auto ret = - GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", - in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), - src_node->GetName().c_str(), dst_node->GetName().c_str()); - return ret; - } - } - return GRAPH_SUCCESS; -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { if (graph->AddInputNode(node) == nullptr) { @@ -1370,7 +1339,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, std::map> &symbol_to_anchors, std::map &anchor_to_symbol) { GE_CHECK_NOTNULL(graph); - for (const auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetAllNodes()) { // in_data_anchor if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); @@ -1427,16 +1396,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); } - const std::string &type = node->GetType(); + std::string type = node->GetType(); if ((type == MERGE) || (type == STREAMMERGE)) { return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); } - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { - const std::string &symbol = cur_node_info.ToString(); + std::string symbol = cur_node_info.ToString(); GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); symbol_to_anchors[symbol] = {cur_node_info}; anchor_to_symbol[symbol] = symbol; @@ -1463,7 +1432,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, std::map> &symbol_to_anchors, std::map &anchor_to_symbol) { GE_CHECK_NOTNULL(node); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { continue; @@ -1477,7 +1446,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, return GRAPH_FAILED; } } else { - const std::string &symbol = cur_node_info.ToString(); + std::string symbol = cur_node_info.ToString(); GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); symbol_to_anchors.emplace(std::make_pair(symbol, std::list{cur_node_info})); anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); @@ -1537,7 +1506,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, GE_CHECK_NOTNULL(node); std::vector exist_node_infos; std::vector cur_node_infos; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { std::string next_name; @@ -1560,10 +1529,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, size_t anchor_nums = 0; NodeIndexIO max_node_index_io(nullptr, 0, kOut); - for (const auto &temp_node_info : exist_node_infos) { + for (auto &temp_node_info : exist_node_infos) { auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); if (iter1 != anchor_to_symbol.end()) { - const std::string &temp_symbol = iter1->second; + std::string temp_symbol = iter1->second; auto iter2 = symbol_to_anchors.find(temp_symbol); if (iter2 != symbol_to_anchors.end()) { if (iter2->second.size() > anchor_nums) { @@ -1575,7 +1544,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, } std::string symbol; - for (const auto &temp_node_info : exist_node_infos) { + for (auto &temp_node_info : exist_node_infos) { if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != GRAPH_SUCCESS) || symbol.empty()) { @@ -1587,7 +1556,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, auto iter = symbol_to_anchors.find(symbol); if (iter != symbol_to_anchors.end()) { - for (const auto &temp_node_info : cur_node_infos) { + for (auto &temp_node_info : cur_node_infos) { GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); iter->second.emplace_back(temp_node_info); anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); @@ -1615,7 +1584,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); @@ -1658,8 +1627,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, std::map> &symbol_to_anchors, std::map &anchor_to_symbol, std::string &symbol) { - const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; - const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; + std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; + std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; if (symbol1 == symbol2) { symbol = symbol1; GELOGI("no need to union."); @@ -1715,7 +1684,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const return GRAPH_FAILED; } - const std::string &symbol = iter1->second; + std::string symbol = iter1->second; auto iter2 = symbol_to_anchors.find(symbol); if (iter2 == symbol_to_anchors.end()) { GE_LOGE("symbol %s not found.", symbol.c_str()); @@ -1743,7 +1712,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t // pass-through op NodePtr node = out_data_anchor->GetOwnerNode(); - const std::string &type = node->GetType(); + std::string type = node->GetType(); const std::set pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { reuse_in_index = output_index; @@ -1786,7 +1755,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t uint32_t reuse_input_index = 0; if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { reuse_in_index = static_cast(reuse_input_index); - GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, + GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, reuse_in_index); return true; } @@ -2328,7 +2297,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & return; } - std::string name = node->GetName() + "_RetVal_" + std::to_string(index); + std::string name = node->GetName() + "_RetVal"; OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); if (ret_val_desc == nullptr) { error_code = GRAPH_FAILED; diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc index 20bcacfb..e4fb8b82 100644 --- a/src/common/graph/utils/node_utils.cc +++ b/src/common/graph/utils/node_utils.cc @@ -296,18 +296,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer return GRAPH_FAILED; } for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); - ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); - bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - if (!is_unknown_graph) { - output_tensor->SetOriginShape(output_tensor->GetShape()); - output_tensor->SetOriginDataType(output_tensor->GetDataType()); - } + GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast(output_tensor.GetShape().GetDims().size())); + output_tensor.SetOriginShape(output_tensor.GetShape()); + output_tensor.SetOriginDataType(output_tensor.GetDataType()); GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); - + node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str()); + (void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); @@ -319,17 +316,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer continue; } GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), - output_tensor->GetDataType(), output_tensor->GetOriginDataType()); - peer_input_desc->SetShape(output_tensor->GetShape()); - peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); - peer_input_desc->SetDataType(output_tensor->GetDataType()); - peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), + output_tensor.GetDataType(), output_tensor.GetOriginDataType()); + peer_input_desc->SetShape(output_tensor.GetShape()); + peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); + peer_input_desc->SetDataType(output_tensor.GetDataType()); + peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); std::vector> shape_range; - (void)output_tensor->GetShapeRange(shape_range); + (void)output_tensor.GetShapeRange(shape_range); peer_input_desc->SetShapeRange(shape_range); ge::TensorUtils::SetRealDimCnt(*peer_input_desc, - static_cast(output_tensor->GetShape().GetDims().size())); + static_cast(output_tensor.GetShape().GetDims().size())); GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); @@ -404,13 +401,10 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { auto desc = node.GetOpDesc(); GE_CHECK_NOTNULL(desc); - // check self - is_unknow = OpShapeIsUnknown(desc); - if (is_unknow) { - return GRAPH_SUCCESS; - } + auto sub_graph_names = desc->GetSubgraphInstanceNames(); if (sub_graph_names.empty()) { + is_unknow = OpShapeIsUnknown(desc); return GRAPH_SUCCESS; } else { auto owner_graph = node.GetOwnerComputeGraph(); @@ -561,53 +555,6 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return peer_out_anchor->GetOwnerNode(); } -/// -/// @brief Check is varying_input for while node -/// @param [in] node: Data node for subgraph -/// @return bool -/// -bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->GetType() != DATA) { - return false; // not input_node for subgraph - } - - const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - if (parent_node == nullptr) { - return false; // root graph - } - - if (kWhileOpTypes.count(parent_node->GetType()) == 0) { - return false; // not input_node for while subgraph - } - - uint32_t index_i = 0; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { - GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); - return false; - } - bool varying_flag = true; - for (const auto &item : node->GetOutDataNodesAndAnchors()) { - if (item.first->GetType() != NETOUTPUT) { - continue; - } - OpDescPtr op_desc = item.first->GetOpDesc(); - uint32_t index_o = 0; - if ((op_desc == nullptr) || - !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { - continue; // input for while-cond subgraph - } - if (index_i != index_o) { - continue; // varying input for while-body subgraph - } - varying_flag = false; - break; - } - return varying_flag; -} - /// /// @brief Get subgraph input is constant. /// @param [in] node @@ -690,86 +637,4 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { return GRAPH_SUCCESS; } -/// -/// @brief Get subgraph input data node by index. -/// @param [in] node -/// @return Node -/// -vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { - vector in_data_node_vec; - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); - return in_data_node_vec; - } - auto compute_graph = node.GetOwnerComputeGraph(); - for (const std::string &instance_name : subgraph_names) { - auto subgraph = compute_graph->GetSubgraph(instance_name); - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - int parent_index = 0; - if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { - (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); - if (parent_index == index) { - in_data_node_vec.emplace_back(node_in_subgraph); - } - } - } - } - return in_data_node_vec; -} -/// -/// @brief Get subgraph input data node by index. -/// @param [in] node -/// @return Node -/// -vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { - vector out_data_node_vec; - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); - return out_data_node_vec; - } - auto compute_graph = node.GetOwnerComputeGraph(); - for (const std::string &instance_name : subgraph_names) { - auto subgraph = compute_graph->GetSubgraph(instance_name); - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { - out_data_node_vec.emplace_back(node_in_subgraph); - } - } - } - return out_data_node_vec; -} - -NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { - if (node.GetInDataAnchor(index) == nullptr) { - return nullptr; - } - if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { - return nullptr; - } - return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); -} - -vector NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { - vector out_data_nodes; - auto out_data_anchor = node.GetOutDataAnchor(index); - if (out_data_anchor == nullptr) { - return out_data_nodes; - } - for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - if (peer_in_anchor == nullptr) { - continue; - } - if (peer_in_anchor->GetOwnerNode() == nullptr) { - continue; - } - out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); - } - return out_data_nodes; -} } // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc index c5de264f..6264ddb9 100644 --- a/src/common/graph/utils/op_desc_utils.cc +++ b/src/common/graph/utils/op_desc_utils.cc @@ -197,33 +197,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: continue; } auto in_node = out_anchor->GetOwnerNode(); - while (true) { - if (in_node == nullptr) { - break; + if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { + ret.push_back(in_node); + } else if (in_node->GetType() == DATA) { + const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); + GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); + + const NodePtr &parent_node = graph->GetParentNode(); + if (parent_node == nullptr) { + continue; // Root graph. } - if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { - ret.push_back(in_node); - break; - } else if (in_node->GetType() == DATA) { - if (NodeUtils::IsWhileVaryingInput(in_node)) { - break; - } - in_node = NodeUtils::GetParentInput(in_node); - } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { - bool is_constant = false; - (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); - if (!is_constant) { - break; - } - // Enter node has and only has one input - if (in_node->GetInDataNodes().size() != 1) { - GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), - in_node->GetInDataNodes().size()); - break; - } - in_node = in_node->GetInDataNodes().at(0); - } else { - break; + + if (kWhileOpTypes.count(parent_node->GetType()) > 0) { + continue; // Subgraph of While cond or body. + } + + NodePtr input_node = NodeUtils::GetParentInput(in_node); + if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { + ret.push_back(input_node); } } } @@ -444,27 +435,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { vector ret; - auto op_desc = node.GetOpDesc(); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); - // Place holder operator, try to get the weight from parent node - // when parent node is const operator - if (node.GetType() == PLACEHOLDER) { - std::string parent_op; - (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); - // This if judgment is necessary because the current subgraph optimization is multithreaded - // and the parent node of the PLD operation should be a stable type, such as const - if (parent_op == CONSTANT || parent_op == CONSTANTOP) { - NodePtr parent_node = nullptr; - parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); - if (parent_node != nullptr) { - op_desc = parent_node->GetOpDesc(); - GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); - } - } - } + GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!"); // Const operator, take the weight directly - if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { - auto weight = MutableWeights(op_desc); + if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) { + auto weight = MutableWeights(node.GetOpDesc()); if (weight == nullptr) { GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); return ret; diff --git a/src/common/graph/utils/tensor_utils.cc b/src/common/graph/utils/tensor_utils.cc index 26ac8cc8..674cab55 100644 --- a/src/common/graph/utils/tensor_utils.cc +++ b/src/common/graph/utils/tensor_utils.cc @@ -19,7 +19,6 @@ #include "debug/ge_log.h" #include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" #include "graph/ge_tensor.h" #include "graph/types.h" #include "graph/utils/type_utils.h" @@ -106,10 +105,7 @@ static graphStatus CalcElementCntByDims(const std::vector &dims, int64_ element_cnt = 1; for (int64_t dim : dims) { if (CheckMultiplyOverflowInt64(element_cnt, dim)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19013", {"function", "var1", "var2"}, - {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); - GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); + GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim); return GRAPH_FAILED; } element_cnt *= dim; @@ -277,6 +273,7 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_FRACTAL_Z: graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); break; + case FORMAT_NC1HWC0_C04: case FORMAT_FRACTAL_NZ: case FORMAT_FRACTAL_ZZ: case FORMAT_NDHWC: @@ -288,7 +285,6 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_NDC1HWC0: case FORMAT_FRACTAL_Z_C04: case FORMAT_FRACTAL_ZN_LSTM: - case FORMAT_NC1HWC0_C04: graph_status = CalcElementCntByDims(dims, element_cnt); break; default: diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc index 5215b141..e4986931 100644 --- a/src/common/graph/utils/type_utils.cc +++ b/src/common/graph/utils/type_utils.cc @@ -147,8 +147,7 @@ static const std::map kStringToFormatMap = { {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, {"FORMAT_RESERVED", FORMAT_RESERVED}, - {"ALL", FORMAT_ALL}, - {"NULL", FORMAT_NULL}}; + {"ALL", FORMAT_ALL}}; static const std::map kDataTypeToStringMap = { {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.