Browse Source

update common/graph, modify libgraph.so

tags/v0.7.0-beta
zhangzhenghai 4 years ago
parent
commit
ea410fb318
24 changed files with 779 additions and 737 deletions
  1. +0
    -1
      src/common/graph/CMakeLists.txt
  2. +0
    -13
      src/common/graph/compute_graph.cc
  3. +0
    -4
      src/common/graph/debug/ge_op_types.h
  4. +21
    -69
      src/common/graph/format_refiner.cc
  5. +4
    -4
      src/common/graph/format_refiner.h
  6. +1
    -20
      src/common/graph/ge_attr_define.cc
  7. +0
    -11
      src/common/graph/ge_tensor.cc
  8. +1
    -1
      src/common/graph/graph.cc
  9. +8
    -66
      src/common/graph/graph.mk
  10. +13
    -12
      src/common/graph/model_serialize.cc
  11. +13
    -34
      src/common/graph/node.cc
  12. +13
    -40
      src/common/graph/op_desc.cc
  13. +40
    -33
      src/common/graph/operator.cc
  14. +0
    -2
      src/common/graph/option/ge_context.cc
  15. +0
    -4
      src/common/graph/ref_relation.cc
  16. +16
    -156
      src/common/graph/shape_refiner.cc
  17. +6
    -0
      src/common/graph/stub/Makefile
  18. +573
    -0
      src/common/graph/stub/gen_stubapi.py
  19. +5
    -5
      src/common/graph/utils/ge_ir_utils.h
  20. +24
    -55
      src/common/graph/utils/graph_utils.cc
  21. +18
    -153
      src/common/graph/utils/node_utils.cc
  22. +20
    -46
      src/common/graph/utils/op_desc_utils.cc
  23. +2
    -6
      src/common/graph/utils/tensor_utils.cc
  24. +1
    -2
      src/common/graph/utils/type_utils.cc

+ 0
- 1
src/common/graph/CMakeLists.txt View File

@@ -71,6 +71,5 @@ target_link_libraries(graph PRIVATE
${PROTOBUF_LIBRARY}
${c_sec}
${slog}
${error_manager}
rt
dl)

+ 0
- 13
src/common/graph/compute_graph.cc View File

@@ -106,15 +106,6 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::share
return Vistor<NodePtr>(shared_from_this(), all_nodes);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes(
bool is_unknown_shape) const {
if (is_unknown_shape) {
return GetDirectNode();
} else {
return GetAllNodes();
}
}

size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
@@ -506,10 +497,6 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<Compute
if (name != subgraph->GetName()) {
GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
}
if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) {
GE_LOGE("The subgraph %s existed", name.c_str());
return GRAPH_PARAM_INVALID;
}
sub_graph_.push_back(subgraph);
names_to_subgraph_[name] = subgraph;
return GRAPH_SUCCESS;


+ 0
- 4
src/common/graph/debug/ge_op_types.h View File

@@ -34,16 +34,12 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims");
GE_REGISTER_OPTYPE(SWITCH, "Switch");
GE_REGISTER_OPTYPE(MERGE, "Merge");
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge");
GE_REGISTER_OPTYPE(ENTER, "Enter");
GE_REGISTER_OPTYPE(REFENTER, "RefEnter");
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration");
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration");
GE_REGISTER_OPTYPE(CONSTANT, "Const");
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder");
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp");
GE_REGISTER_OPTYPE(GETNEXT, "GetNext");
GE_REGISTER_OPTYPE(INITDATA, "InitData");
GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity");
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData");

GE_REGISTER_OPTYPE(CONSTANTOP, "Constant");


+ 21
- 69
src/common/graph/format_refiner.cc View File

@@ -41,9 +41,11 @@ using namespace ge;
using namespace std;
namespace ge {
namespace {
const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
const string kIsGraphInferred = "_is_graph_inferred";
RefRelations reflection_builder;
static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
static bool net_format_is_nd = true;
static Format g_user_set_format = FORMAT_ND;
static bool is_first_infer = true;
static RefRelations reflection_builder;
} // namespace

graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection,
@@ -70,49 +72,9 @@ graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &re
return GRAPH_SUCCESS;
}

graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) {
// 5 meas dim num
if (node_ptr->GetType() != "BiasAdd") {
return GRAPH_SUCCESS;
}
std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}};
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) {
auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i);
GE_CHECK_NOTNULL(in_desc);
if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num
continue;
}
auto format = in_desc->GetOriginFormat();
auto key = TypeUtils::FormatToSerialString(format);
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
in_desc->SetOriginFormat(fixed_format);
in_desc->SetFormat(fixed_format);
GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i,
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::FormatToSerialString(fixed_format).c_str());
}
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) {
auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i);
GE_CHECK_NOTNULL(out_desc);
if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num
continue;
}
auto format = out_desc->GetOriginFormat();
auto key = TypeUtils::FormatToSerialString(format);
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
out_desc->SetOriginFormat(fixed_format);
out_desc->SetFormat(fixed_format);
GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i,
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::FormatToSerialString(fixed_format).c_str());
}
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(graph);
graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) {
if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) {
ConstGeTensorPtr tensor_value;
if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str());
@@ -133,7 +95,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
}
anchor_points.clear();
// Get all anchor point nodes and switch nodes
for (auto &node_ptr : graph->GetAllNodes()) {
for (const auto &node_ptr : graph->GetAllNodes()) {
if (node_ptr == nullptr) {
return GRAPH_FAILED;
}
@@ -141,7 +103,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
if (op_desc == nullptr) {
return GRAPH_FAILED;
}
graphStatus status = RefreshConstantOutProcess(graph, op_desc);
graphStatus status = RefreshConstantOutProcess(op_desc);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "refresh constant out process failed!");
return GRAPH_FAILED;
@@ -173,16 +135,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
if (!node_is_all_nd) {
continue;
}
// special process for biasAdd op
// In tensorflow, biasAdd's format is alwayse NHWC even though set the arg
// "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism
// so here do special process
status = BiasAddFormatFixProcess(node_ptr);
if (status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "fix biasAdd process failed!");
return GRAPH_FAILED;
}

GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str());
anchor_points.push_back(node_ptr);
}
@@ -392,11 +344,14 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor
}
}

graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format,
void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; }

graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status) {
if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) {
GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph),
bool is_internal_format = TypeUtils::IsInternalFormat(data_format);
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND);
if (!need_process) {
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer,
TypeUtils::FormatToSerialString(data_format).c_str());
return GRAPH_SUCCESS;
}
@@ -455,6 +410,8 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
std::vector<ge::NodePtr> anchor_points;
std::vector<ge::NodePtr> data_nodes;
// global net format
net_format_is_nd = true;
g_user_set_format = FORMAT_ND;

if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "input graph is null");
@@ -491,15 +448,10 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
/// format for these data nodes.
/// Notice: ignore 5D formats
auto data_format = graph->GetDataFormat();
status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status);
(void)AttrUtils::SetBool(graph, kIsGraphInferred, true);
status = DataNodeFormatProcess(data_nodes, data_format, node_status);
// Set infer flag to false
SetInferOrigineFormatFlag(false);

return status;
}

bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) {
bool is_graph_inferred = false;
return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred);
}
} // namespace ge

+ 4
- 4
src/common/graph/format_refiner.h View File

@@ -30,9 +30,10 @@ namespace ge {
class FormatRefiner {
public:
static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph);
static void SetInferOrigineFormatFlag(bool is_first = true);

private:
static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc);
static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
std::vector<ge::NodePtr> &data_nodes,
std::unordered_map<ge::NodePtr, bool> &node_status);
@@ -42,9 +43,8 @@ class FormatRefiner {
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
std::unordered_map<ge::NodePtr, bool> &node_status);
static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status);
static bool IsGraphInferred(const ComputeGraphPtr &graph);
static graphStatus DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status);
};
} // namespace ge
#endif // COMMON_GRAPH_FORMAT_REFINER_H_

+ 1
- 20
src/common/graph/ge_attr_define.cc View File

@@ -725,10 +725,6 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name";

const std::string ATTR_MODEL_CORE_TYPE = "core_type";

const std::string ATTR_MODEL_ATC_VERSION = "atc_version";

const std::string ATTR_MODEL_OPP_VERSION = "opp_version";

// Public attribute
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type";

@@ -938,7 +934,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace";

const std::string MODEL_ATTR_SESSION_ID = "session_id";

// lx fusion
// l1 fusion and other fusion in future
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id";
const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key";
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key";
@@ -952,17 +948,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion";
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split";
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed";
const std::string ATTR_DATA_DUMP_REF = "_datadump_ref";
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion";
const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id";
const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion";
const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag";
const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr";
const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size";

// Op debug attrs
const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag";
const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode";

// Atomic addr clean attrs
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index";
@@ -1027,11 +1015,4 @@ const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op";
// used for allreduce tailing optimization
const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group";
const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node";

// dynamic shape attr
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr";
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index";

// for fusion op plugin
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";
} // namespace ge

+ 0
- 11
src/common/graph/ge_tensor.cc View File

@@ -220,7 +220,6 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape";
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format";
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type";
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range";
const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index";

GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {}

@@ -568,16 +567,6 @@ DataType GeTensorDesc::GetOriginDataType() const {
return TypeUtils::SerialStringToDataType(origin_data_type_str);
}

std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const {
vector<uint32_t> ref_port_index;
(void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index);
return ref_port_index;
}

void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) {
(void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index);
}

graphStatus GeTensorDesc::IsValid() const {
auto dtype = this->GetDataType();
auto format = this->GetFormat();


+ 1
- 1
src/common/graph/graph.cc View File

@@ -210,7 +210,7 @@ class GraphImpl {

graphStatus FindOpByName(const string &name, ge::Operator &op) const {
auto it = op_list_.find(name);
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str());
op = it->second;
return GRAPH_SUCCESS;
}


+ 8
- 66
src/common/graph/graph.mk View File

@@ -77,7 +77,6 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -95,35 +94,10 @@ LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2
LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/atc/lib64/stub/graph.cc \
../../out/atc/lib64/stub/operator.cc \
../../out/atc/lib64/stub/tensor.cc \
../../out/atc/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=
@@ -148,7 +122,6 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -169,38 +142,10 @@ LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

ifeq ($(device_os),android)
LOCAL_LDFLAGS := -ldl
endif

LOCAL_MULTILIB := 64
LOCAL_PROPRIETARY_MODULE := true

include $(BUILD_SHARED_LIBRARY)

#compiler for device
include $(CLEAR_VARS)
LOCAL_MODULE := fwk_stub/libgraph

LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/atc/lib64/stub/graph.cc \
../../out/atc/lib64/stub/operator.cc \
../../out/atc/lib64/stub/tensor.cc \
../../out/atc/lib64/stub/operator_factory.cc \


LOCAL_SHARED_LIBRARIES :=
@@ -229,7 +174,6 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
libprotobuf \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -255,7 +199,6 @@ LOCAL_STATIC_LIBRARIES := \
LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl

@@ -279,7 +222,6 @@ LOCAL_STATIC_LIBRARIES := \
LOCAL_SHARED_LIBRARIES := \
libc_sec \
libslog \
liberror_manager \

LOCAL_LDFLAGS := -lrt -ldl



+ 13
- 12
src/common/graph/model_serialize.cc View File

@@ -88,8 +88,10 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_
}

bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
if (op_desc == nullptr || op_def_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input Para Invalid");
return false;
}
if (op_desc->op_def_.GetProtoMsg() != nullptr) {
*op_def_proto = *op_desc->op_def_.GetProtoMsg();
// Delete unnecessary attr
@@ -128,17 +130,16 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
op_def_proto->add_subgraph_name(name);
}
if (!op_desc->output_name_idx_.empty()) {
proto::AttrDef key;
proto::AttrDef value;
for (auto &item : op_desc->output_name_idx_) {
key.mutable_list()->add_s(item.first);
value.mutable_list()->add_i(item.second);
}
auto op_desc_attr = op_def_proto->mutable_attr();
op_desc_attr->insert({"_output_name_key", key});
op_desc_attr->insert({"_output_name_value", value});

proto::AttrDef key;
proto::AttrDef value;
for (auto &item : op_desc->output_name_idx_) {
key.mutable_list()->add_s(item.first);
value.mutable_list()->add_i(item.second);
}
auto op_desc_attr = op_def_proto->mutable_attr();
op_desc_attr->insert({"_output_name_key", key});
op_desc_attr->insert({"_output_name_value", value});
}
return true;
}


+ 13
- 34
src/common/graph/node.cc View File

@@ -26,7 +26,6 @@
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "common/util/error_manager/error_manager.h"

using std::string;
using std::vector;
@@ -155,7 +154,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons
const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode();
if (peer_node == nullptr || r_peer_node == nullptr) {
GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ",
this->GetName().c_str(), i, j);
return false;
}
@@ -435,11 +434,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
return nullptr;
} else {
return in_data_anchors_[idx];
@@ -449,10 +445,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const {
// Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "in_anchor", GetType().c_str()});
GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str());
GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str());
return nullptr;
} else {
// Return control anchor
@@ -468,15 +461,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const {
// Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_
if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"},
{
GetName().c_str(),
std::to_string(idx),
"out_anchor",
GetType().c_str(),
});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
return nullptr;
} else {
// Return control anchor
@@ -491,11 +477,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const {
if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19019", {"opname", "index", "anchorname", "optype"},
{GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()});
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx,
GetType().c_str());
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(),
GetName().c_str());
return nullptr;
} else {
return out_data_anchors_[idx];
@@ -750,15 +733,11 @@ graphStatus Node::Verify() const {
GELOGW("in anchor ptr is null");
continue;
}
bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type ||
op_->GetType() == const_type || op_->GetType() == variable_type ||
op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0;
if (!valid_anchor) {
ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"name", "index"},
{GetName(), std::to_string(in_anchor_ptr->GetIdx())});
GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
return GRAPH_FAILED;
}
GE_CHK_BOOL_RET_STATUS(
op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type ||
op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) ||
in_anchor_ptr->GetPeerAnchors().size() > 0,
GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx());
}

string frameworkop_type = "FrameworkOp";


+ 13
- 40
src/common/graph/op_desc.cc View File

@@ -19,7 +19,6 @@
#include "debug/ge_util.h"
#include "external/graph/operator.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/operator_factory_impl.h"
@@ -471,25 +470,6 @@ GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
return *(inputs_desc_[it->second].get());
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
if (inputs_desc_[index] == nullptr) {
return nullptr;
}
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid");
return inputs_desc_[index];
}

GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
if (it == input_name_idx.end()) {
GELOGW("Failed to get [%s] input desc", name.c_str());
return nullptr;
}
return MutableInputDesc(it->second);
}

GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
auto input_name_idx = GetAllInputName();
vector<string> names;
@@ -516,6 +496,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
if (inputs_desc_[index] == nullptr) {
return nullptr;
}
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid");
return inputs_desc_[index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const {
vector<GeTensorDesc> temp{};
for (const auto &it : inputs_desc_) {
@@ -620,15 +609,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu
return outputs_desc_[index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const {
auto it = output_name_idx_.find(name);
if (it == output_name_idx_.end()) {
GELOGW("Failed to get [%s] output desc", name.c_str());
return nullptr;
}
return MutableOutputDesc(it->second);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
return static_cast<uint32_t>(outputs_desc_.size());
}
@@ -902,22 +882,15 @@ graphStatus OpDesc::CommonVerify() const {
// Checking shape of all inputs
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims();
for (int64_t dim : ishape) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
dim < -2, ErrorManager::GetInstance().ATCReportErrMessage(
"E19014", {"opname", "value", "reason"},
{GetName(), "input " + iname + " shape", "contains negative or zero dimension"});
return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(),
iname.c_str());
GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
iname.c_str());
}
}
// Check all attributes defined
const auto &all_attributes = GetAllAttrs();
for (const auto &name : GetAllAttrNames()) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
all_attributes.find(name) == all_attributes.end(),
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
{GetName(), "attribute " + name, "is empty"});
return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str());
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
"operator attribute %s is empty.", name.c_str());
}

return GRAPH_SUCCESS;


+ 40
- 33
src/common/graph/operator.cc View File

@@ -36,8 +36,6 @@
#include "graph/op_desc.h"
#include "graph/runtime_inference_context.h"
#include "graph/usr_types.h"
#include "graph/utils/node_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "utils/graph_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_adapter.h"
@@ -59,7 +57,8 @@ using std::vector;
namespace ge {
class OpIO {
public:
OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {}
explicit OpIO(const string &name, int index, const OperatorImplPtr &owner)
: name_(name), index_(index), owner_(owner) {}

~OpIO() = default;

@@ -547,46 +546,56 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) {
}

graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const {
GE_CHECK_NOTNULL(operator_impl_);
auto node_ptr = operator_impl_->GetNode();
if (node_ptr != nullptr) {
if (operator_impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "operator impl is nullptr.");
return GRAPH_FAILED;
}
ge::ConstNodePtr node_ptr = operator_impl_->GetNode();
if (node_ptr) {
// For inner compute graph
auto op_desc = node_ptr->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc is nullptr.");
return GRAPH_FAILED;
}
auto index = op_desc->GetInputIndexByName(dst_name);
auto in_data_anchor = node_ptr->GetInDataAnchor(index);
GE_CHECK_NOTNULL(in_data_anchor);
if (in_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr.");
return GRAPH_FAILED;
}
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(out_data_anchor);
auto peer_node = out_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(peer_node);
auto peer_op_desc = peer_node->GetOpDesc();
GE_CHECK_NOTNULL(peer_op_desc);
auto peer_op_type = peer_op_desc->GetType();
if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) {
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node);
GE_CHECK_NOTNULL(const_op_impl);
Operator const_op(std::move(const_op_impl));
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
} else if (peer_op_type == DATA) {
auto parent_node = NodeUtils::GetParentInput(peer_node);
while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
parent_node = NodeUtils::GetParentInput(parent_node);
}
if ((parent_node != nullptr) &&
((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node);
GE_CHECK_NOTNULL(const_op_impl);
Operator const_op(std::move(const_op_impl));
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data);
if (out_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr.");
return GRAPH_FAILED;
}
std::shared_ptr<Node> peer_node_ptr = out_data_anchor->GetOwnerNode();
if (peer_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr.");
return GRAPH_FAILED;
}
ge::OperatorImplPtr operator_impl_ptr = nullptr;
operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(peer_node_ptr);
if (operator_impl_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed");
return GRAPH_FAILED;
}
Operator const_op(std::move(operator_impl_ptr));
if (peer_node_ptr->GetOpDesc() != nullptr) {
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType();
if (op_descType == CONSTANTOP) {
return const_op.GetAttr(op::Constant::name_attr_value(), data);
} else if (op_descType == CONSTANT) {
return const_op.GetAttr(op::Const::name_attr_value(), data);
}
}

// Try get from runtime inference context
auto session_id = std::to_string(GetContext().SessionId());
RuntimeInferenceContext *runtime_infer_ctx = nullptr;
if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) {
GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str());
auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data);
if (ret == GRAPH_SUCCESS) {
return GRAPH_SUCCESS;
}
@@ -595,8 +604,6 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co
// For outer graph
return GetInputConstDataOut(dst_name, data);
}
auto op_name = operator_impl_->GetName();
GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str());
return GRAPH_FAILED;
}
graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const {


+ 0
- 2
src/common/graph/option/ge_context.cc View File

@@ -85,8 +85,6 @@ uint32_t GEContext::DeviceId() { return device_id_; }

uint64_t GEContext::TraceId() { return trace_id_; }

void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; }

void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; }

} // namespace ge

+ 0
- 4
src/common/graph/ref_relation.cc View File

@@ -242,10 +242,6 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r
int sub_graph_idx = 0;
for (const auto &name : sub_graph_names) {
auto sub_graph = root_graph.GetSubgraph(name);
if (sub_graph == nullptr) {
GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str());
continue;
}
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) {
auto sub_graph_node_type = sub_graph_node->GetType();



+ 16
- 156
src/common/graph/shape_refiner.cc View File

@@ -37,115 +37,6 @@

namespace ge {
namespace {
const uint32_t kWhileBodySubGraphIdx = 1;

graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
GELOGD("Enter reverse brush while body subgraph process!");

auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
if (sub_graph_body == nullptr) {
GELOGE(GRAPH_FAILED, "Get while body graph failed!");
return GRAPH_FAILED;
}

for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
if (node_sub->GetInDataNodes().size() == 0) {
continue;
}

for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i);
(void)input_desc->SetUnknownDimNumShape();
}
for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i);
(void)output_desc->SetUnknownDimNumShape();
}
}

return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
GELOGD("Enter update parent node shape for class branch op process");
// check sub_graph shape.If not same ,do unknown shape process
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
for (auto &tensor : ref_out_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
break;
}
for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
continue;
}
GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
(void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
}
}
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
}
return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
GELOGD("Enter update parent node shape for class while op process");
if (ref_data_tensors.size() != ref_out_tensors.size()) {
GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(),
ref_data_tensors.size(), ref_out_tensors.size());
return GRAPH_FAILED;
}
for (size_t i = 0; i < ref_data_tensors.size(); i++) {
if (ref_out_tensors[i].size() != 1) {
GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!");
return GRAPH_FAILED;
}
}
bool is_need_reverse_brush = false;
// check input and output
for (size_t i = 0; i < ref_out_tensors.size(); i++) {
if (ref_out_tensors[i].empty()) {
continue;
}
auto ref_out_tensor = ref_out_tensors[i].at(0);
auto tmp_shape = ref_out_tensor.MutableShape();
// ref_i's data and output tensor shape should be same
for (auto &tensor : ref_data_tensors[i]) {
if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str());
return GRAPH_FAILED;
}
auto shape = tensor.MutableShape();
if (shape.GetDims() != tmp_shape.GetDims()) {
ref_out_tensor.SetUnknownDimNumShape();
is_need_reverse_brush = true;
break;
}
}
(void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
}
// reverse refresh while body shape
if (is_need_reverse_brush) {
return ReverseBrushWhileBodySubGraph(node);
}
return GRAPH_SUCCESS;
}

graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
@@ -207,37 +98,6 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
}
return GRAPH_SUCCESS;
}

graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
const ConstNodePtr &node,
std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
}
if (sub_node->GetType() == DATA) {
if (sub_node->GetOpDesc() == nullptr) {
return GRAPH_FAILED;
}

int ref_i;
if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
return GRAPH_FAILED;
}
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(),
ref_i, node->GetAllInDataAnchorsSize());
return GRAPH_FAILED;
}
ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
}
}
return GRAPH_SUCCESS;
}

graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
@@ -245,10 +105,7 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
return GRAPH_SUCCESS;
}

std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());

for (const auto &name : sub_graph_names) {
if (name.empty()) {
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
@@ -260,9 +117,13 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
return GRAPH_FAILED;
}
NodePtr netoutput = nullptr;
auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
if (ret != GRAPH_SUCCESS) {
return ret;
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
break;
}
}
if (netoutput == nullptr) {
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
@@ -289,17 +150,19 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
continue;
}
GELOGI("Parent node index of edge desc is %d", ref_i);
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i));
if (output_desc == nullptr) {
GE_LOGE(
"The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
"parent node %s are incompatible, outputs num %u",
ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
node->GetAllOutDataAnchorsSize());
return GRAPH_FAILED;
}
ref_out_tensors[ref_i].emplace_back(*edge_desc);
op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc);
}
}

if (node->GetType() == WHILE) {
return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors);
}
return UpdateParentNodeForBranch(node, ref_out_tensors);
return GRAPH_SUCCESS;
}
} // namespace
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
@@ -307,9 +170,6 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
GELOGE(GRAPH_FAILED, "node is null");
return;
}
if (!IsLogEnable(GE, DLOG_DEBUG)) {
return;
}
ge::OpDescPtr op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );
std::string str;


+ 6
- 0
src/common/graph/stub/Makefile View File

@@ -0,0 +1,6 @@
inc_path := $(shell pwd)/inc/external/
out_path := $(shell pwd)/out/atc/lib64/stub/
stub_path := $(shell pwd)/common/graph/stub/

mkdir_stub := $(shell mkdir -p $(out_path))
graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path))

+ 573
- 0
src/common/graph/stub/gen_stubapi.py View File

@@ -0,0 +1,573 @@
import os
import re
import sys
import logging

logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
level=logging.INFO)

"""
this attr is used for symbol table visible
"""
GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'

"""
generate stub func body by return type
"""
RETURN_STATEMENTS = {
'graphStatus': ' return GRAPH_SUCCESS;',
'Status': ' return SUCCESS;',
'Graph': ' return Graph();',
'Graph&': ' return *this;',
'Format': ' return Format();',
'Format&': ' return *this;',
'Shape': ' return Shape();',
'Shape&': ' return *this;',
'TensorDesc': ' return TensorDesc();',
'TensorDesc&': ' return *this;',
'Tensor': ' return Tensor();',
'Tensor&': ' return *this;',
'Operator': ' return Operator();',
'Operator&': ' return *this;',
'Ptr': ' return nullptr;',
'std::string': ' return "";',
'std::string&': ' return "";',
'string': ' return "";',
'int': ' return 0;',
'DataType': ' return DT_FLOAT;',
'InferenceContextPtr': ' return nullptr;',
'SubgraphBuilder': ' return nullptr;',
'OperatorImplPtr': ' return nullptr;',
'OutHandler': ' return nullptr;',
'std::vector<std::string>': ' return {};',
'std::vector<int64_t>': ' return {};',
'std::map': ' return {};',
'uint32_t': ' return 0;',
'int64_t': ' return 0;',
'uint64_t': ' return 0;',
'size_t': ' return 0;',
'float': ' return 0.0f;',
'bool': ' return false;',
}

"""
max code len per line in hua_wei software programming specifications
"""
max_code_len_per_line = 100

"""
white_list_for_debug, include_dir_key_words is to
determines which header files to generate cc files from
when DEBUG on
"""
white_list_for_debug = ["operator.h", "tensor.h",
"graph.h", "operator_factory.h",
"ge_ir_build.h"]
include_dir_key_words = ["ge", "graph"]
DEBUG = True


def need_generate_func(func_line):
"""
:param func_line:
:return:
"""
if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \
or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"):
return False
return True


def file_endswith_white_list_suffix(file):
"""
:param file:
:return:
"""
if DEBUG:
for suffix in white_list_for_debug:
if file.endswith(suffix):
return True
return False
else:
return True


"""
belows are patterns used for analyse .h file
"""
# pattern function
pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after
([a-zA-Z~_] # void int likely
.*
[)] #we find )
(?!.*{) # we do not want the case int abc() const { return 1;}
.*)
(;.*) #we want to find ; and after for we will replace these later
\n$
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# pattern comment
pattern_comment = re.compile(r'^\s*//')
pattern_comment_2_start = re.compile(r'^\s*/[*]')
pattern_comment_2_end = re.compile(r'[*]/\s*$')
# pattern define
pattern_define = re.compile(r'^\s*#define')
pattern_define_return = re.compile(r'\\\s*$')
# blank line
pattern_blank_line = re.compile(r'^\s*$')
# virtual,explicit,friend,static
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
# lead space
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
# functions will have patterns such as func ( or func(
# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist
# format like :"operator = ()"
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
# template
pattern_template = re.compile(r'^\s*template')
pattern_template_end = re.compile(r'>\s*$')
# namespace
pattern_namespace = re.compile(r'namespace.*{')
# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
# {}
pattern_start = re.compile('{')
pattern_end = re.compile('}')

line_index = 0


class H2CC(object):
def __init__(self, input_file, output_file, shared_includes_content):
"""
:param input_file:
:param output_file:
:param shared_includes_content:
"""
self.input_file = input_file
self.output_file = output_file
self.shared_includes_content = shared_includes_content
self.line_index = 0
self.input_fd = open(self.input_file, 'r')
self.input_content = self.input_fd.readlines()
self.output_fd = open(self.output_file, 'w')

# The state may be normal_now(in the middle of {}),class_now,namespace_now
self.stack = []
self.stack_class = []
self.stack_template = []
# record funcs generated by h2cc func
self.func_list_exist = []

def __del__(self):
self.input_fd.close()
self.output_fd.close()
del self.stack
del self.stack_class
del self.stack_template
del self.func_list_exist

def just_skip(self):
# skip blank line or comment
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
self.input_content[self.line_index]): # /n or comment using //
self.line_index += 1
if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /*
while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */
self.line_index += 1
self.line_index += 1
# skip define
if pattern_define.search(self.input_content[self.line_index]):
while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
self.input_content[self.line_index]):
self.line_index += 1
self.line_index += 1

def write_inc_content(self):
for shared_include_content in self.shared_includes_content:
self.output_fd.write(shared_include_content)

def h2cc(self):
"""
:return:
"""
logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
global pattern_comment
global pattern_comment_2_start
global pattern_comment_2_end
global pattern_blank_line
global pattern_func
global pattern_keyword
global pattern_leading_space
global pattern_func_name
global pattern_template
global pattern_template_end
global pattern_namespace
global pattern_class
global pattern_start
global pattern_end
global line_index
# write inc content
self.write_inc_content()
# core processing cycle, process the input .h file by line
while self.line_index < len(self.input_content):
# handle comment and blank line
self.just_skip()

# match namespace
self.handle_namespace()

# match template
template_string = self.handle_template()
# match class
line = self.input_content[self.line_index]
match_class = pattern_class.search(line)
match_start = pattern_start.search(line)
handle_class_result = self.handle_class(template_string, line, match_start, match_class)
if handle_class_result == "continue":
continue

# match "}"
handle_stack_result = self.handle_stack(match_start)
if handle_stack_result == "continue":
continue
# handle func
handle_func1_result, line, start_i = self.handle_func1(line)
if handle_func1_result == "continue":
continue

# here means func is found
# delete key word
line = pattern_keyword.sub('', line)
logging.info("line[%s]", line)

# Class member function
# if friend we will not add class name
friend_match = re.search('friend ', line)
if len(self.stack_class) > 0 and not friend_match:
line, func_name = self.handle_class_member_func(line, template_string)
# Normal functions
else:
line, func_name = self.handle_normal_func(line, template_string)

need_generate = need_generate_func(line)
# func body
line += self.implement_function(line)
# comment
line = self.gen_comment(start_i) + line
# write to out file
self.write_func_content(line, func_name, need_generate)
# next loop
self.line_index += 1

logging.info('Added %s functions', len(self.func_list_exist))
logging.info('Successfully converted,please see ' + self.output_file)

def handle_func1(self, line):
"""
:param line:
:return:
"""
find1 = re.search('[(]', line)
if not find1:
self.line_index += 1
return "continue", line, None
find2 = re.search('[)]', line)
start_i = self.line_index
space_match = pattern_leading_space.search(line)
# deal with
# int abc(int a,
# int b)
if find1 and (not find2):
self.line_index += 1
line2 = self.input_content[self.line_index]
if space_match:
line2 = re.sub('^' + space_match.group(1), '', line2)
line += line2
while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
self.line_index += 1
line2 = self.input_content[self.line_index]
line2 = re.sub('^' + space_match.group(1), '', line2)
line += line2

match_start = pattern_start.search(self.input_content[self.line_index])
match_end = pattern_end.search(self.input_content[self.line_index])
if match_start: # like ) { or ) {} int the last line
if not match_end:
self.stack.append('normal_now')
ii = start_i
while ii <= self.line_index:
ii += 1
self.line_index += 1
return "continue", line, start_i
logging.info("line[%s]", line)
# ' int abc();'->'int abc()'
(line, match) = pattern_func.subn(r'\2\n', line)
logging.info("line[%s]", line)
# deal with case:
# 'int \n abc(int a, int b)'
if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
line = self.input_content[start_i - 1] + line
line = line.lstrip()
if not match:
self.line_index += 1
return "continue", line, start_i
return "pass", line, start_i

def handle_stack(self, match_start):
"""
:param match_start:
:return:
"""
line = self.input_content[self.line_index]
match_end = pattern_end.search(line)
if match_start:
self.stack.append('normal_now')
if match_end:
top_status = self.stack.pop()
if top_status == 'namespace_now':
self.output_fd.write(line + '\n')
elif top_status == 'class_now':
self.stack_class.pop()
self.stack_template.pop()
if match_start or match_end:
self.line_index += 1
return "continue"

if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
self.line_index += 1
return "continue"
return "pass"

def handle_class(self, template_string, line, match_start, match_class):
"""
:param template_string:
:param line:
:param match_start:
:param match_class:
:return:
"""
if match_class: # we face a class
self.stack_template.append(template_string)
self.stack.append('class_now')
class_name = match_class.group(3)

# class template specializations: class A<u,Node<u> >
if '<' in class_name:
k = line.index('<')
fit = 1
for ii in range(k + 1, len(line)):
if line[ii] == '<':
fit += 1
if line[ii] == '>':
fit -= 1
if fit == 0:
break
class_name += line[k + 1:ii + 1]
logging.info('class_name[%s]', class_name)
self.stack_class.append(class_name)
while not match_start:
self.line_index += 1
line = self.input_content[self.line_index]
match_start = pattern_start.search(line)
self.line_index += 1
return "continue"
return "pass"

def handle_template(self):
line = self.input_content[self.line_index]
match_template = pattern_template.search(line)
template_string = ''
if match_template:
match_template_end = pattern_template_end.search(line)
template_string = line
while not match_template_end:
self.line_index += 1
line = self.input_content[self.line_index]
template_string += line
match_template_end = pattern_template_end.search(line)
self.line_index += 1
return template_string

def handle_namespace(self):
line = self.input_content[self.line_index]
match_namespace = pattern_namespace.search(line)
if match_namespace: # we face namespace
self.output_fd.write(line + '\n')
self.stack.append('namespace_now')
self.line_index += 1

def handle_normal_func(self, line, template_string):
template_line = ''
self.stack_template.append(template_string)
if self.stack_template[-1] != '':
template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
# change '< class T = a, class U = A(3)>' to '<class T, class U>'
template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
template_line = re.sub(r'\s*=.*,', ',', template_line)
template_line = re.sub(r'\s*=.*', '', template_line)
line = re.sub(r'\s*=.*,', ',', line)
line = re.sub(r'\s*=.*\)', ')', line)
line = template_line + line
self.stack_template.pop()
func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
logging.info("line[%s]", line)
logging.info("func_name[%s]", func_name)
return line, func_name

def handle_class_member_func(self, line, template_string):
template_line = ''
x = ''
if template_string != '':
template_string = re.sub(r'\s*template', 'template', template_string)
template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
template_string = re.sub(r'\s*=.*,', ',', template_string)
template_string = re.sub(r'\s*=.*', '', template_string)
if self.stack_template[-1] != '':
if not (re.search(r'<\s*>', stack_template[-1])):
template_line = re.sub(r'^\s*template', 'template', stack_template[-1])
if not (re.search(r'<.*>', self.stack_class[-1])):
# for x we get like template<class T, typename U> -> <T,U>
x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U>
x = re.sub(r'\n', '', x)
x = re.sub(r'\s*=.*,', ',', x)
x = re.sub(r'\s*=.*\>', '>', x)
x = x.rstrip() # remove \n
x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
x) # remove class,typename -> <T, U>
x = re.sub(r'<\s+', '<', x)
x = re.sub(r'\s+>', '>', x)
x = re.sub(r'\s+,', ',', x)
x = re.sub(r',\s+', ', ', x)
line = re.sub(r'\s*=\s+0', '', line)
line = re.sub(r'\s*=\s+.*,', ',', line)
line = re.sub(r'\s*=\s+.*\)', ')', line)
logging.info("x[%s]\nline[%s]", x, line)
# if the function is long, void ABC::foo()
# breaks into two lines void ABC::\n foo()
temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
if len(temp_line) > max_code_len_per_line:
line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
else:
line = temp_line
logging.info("line[%s]", line)
# add template as the above if there is one
template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
template_line = re.sub(r'\s*=.*,', ',', template_line)
template_line = re.sub(r'\s*=.*', '', template_line)
line = template_line + template_string + line
func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
logging.info("line[%s]", line)
logging.info("func_name[%s]", func_name)
return line, func_name

def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)

def gen_comment(self, start_i):
comment_line = ''
# Function comments are on top of function declarations, copy them over
k = start_i - 1 # one line before this func start
if pattern_template.search(self.input_content[k]):
k -= 1
if pattern_comment_2_end.search(self.input_content[k]):
comment_line = self.input_content[k].lstrip()
while not pattern_comment_2_start.search(self.input_content[k]):
k -= 1
comment_line = self.input_content[k].lstrip() + comment_line
else:
for j in range(k, 0, -1):
c_line = self.input_content[j]
if pattern_comment.search(c_line):
c_line = re.sub(r'\s*//', '//', c_line)
comment_line = c_line + comment_line
else:
break
return comment_line

@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if RETURN_STATEMENTS.__contains__(return_type):
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def


def collect_header_files(path):
"""
:param path:
:return:
"""
header_files = []
shared_includes_content = []
for root, dirs, files in os.walk(path):
files.sort()
for file in files:
if file.find("git") >= 0:
continue
if not file.endswith('.h'):
continue
file_path = os.path.join(root, file)
file_path = file_path.replace('\\', '/')
header_files.append(file_path)
include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
shared_includes_content.append(include_str)
return header_files, shared_includes_content


def generate_stub_file(inc_dir, out_cc_dir):
"""
:param inc_dir:
:param out_cc_dir:
:return:
"""
target_header_files, shared_includes_content = collect_header_files(inc_dir)
for header_file in target_header_files:
if not file_endswith_white_list_suffix(header_file):
continue
cc_file = re.sub('.h*$', '.cc', header_file)
h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
h_2_cc.h2cc()


def gen_code(inc_dir, out_cc_dir):
"""
:param inc_dir:
:param out_cc_dir:
:return:
"""
if not inc_dir.endswith('/'):
inc_dir += '/'
if not out_cc_dir.endswith('/'):
out_cc_dir += '/'
for include_dir_key_word in include_dir_key_words:
generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir)


if __name__ == '__main__':
inc_dir = sys.argv[1]
out_cc_dir = sys.argv[2]
gen_code(inc_dir, out_cc_dir)

+ 5
- 5
src/common/graph/utils/ge_ir_utils.h View File

@@ -1,18 +1,18 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
*/

#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_


+ 24
- 55
src/common/graph/utils/graph_utils.cc View File

@@ -38,7 +38,6 @@
#include "utils/ge_ir_utils.h"
#include "utils/node_utils.h"
#include "debug/ge_op_types.h"
#include "external/ge/ge_api_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
@@ -411,8 +410,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra
/// @return graphStatus
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) {
GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) {
GE_CHECK_NOTNULL(src);
GE_CHECK_NOTNULL(insert_node);

@@ -571,7 +570,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons
static int max_dumpfile_num = 0;
if (max_dumpfile_num == 0) {
string opt = "0";
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
(void)GetContext().GetOption("ge.maxDumpFileNum", opt);
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) {
@@ -671,7 +670,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText
if (maxDumpFileSize == 0) {
string opt = "0";
// Can not check return value
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt);
(void)GetContext().GetOption("ge.maxDumpFileSize", opt);
maxDumpFileSize = atol(opt.c_str());
}
if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) {
@@ -741,7 +740,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn
static int max_dumpfile_num = 0;
if (max_dumpfile_num == 0) {
string opt = "0";
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
(void)GetContext().GetOption("ge.maxDumpFileNum", opt);
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) {
@@ -921,7 +920,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector<int> &io_map, In
InNodesToOut GetFullConnectIONodes(const NodePtr &node) {
InNodesToOut in_nodes_to_out;
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node is nullptr");
GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str());
return in_nodes_to_out;
}
auto in_nodes_list = node->GetInNodes();
@@ -1309,36 +1308,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt
return GRAPH_SUCCESS;
}

///
/// Copy all in-data edges from `src_node` to `dst_node`.
/// @param src_node
/// @param dst_node
/// @return
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node,
NodePtr &dst_node) {
if ((src_node == nullptr) || (dst_node == nullptr)) {
GELOGE(GRAPH_FAILED, "Parameter is nullptr");
return GRAPH_PARAM_INVALID;
}
auto src_data_in_nodes = src_node->GetInDataNodes();
if (src_data_in_nodes.empty()) {
return GRAPH_SUCCESS;
}
for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) {
auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
auto ret =
GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s",
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(),
src_node->GetName().c_str(), dst_node->GetName().c_str());
return ret;
}
}
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph,
const NodePtr &node) {
if (graph->AddInputNode(node) == nullptr) {
@@ -1370,7 +1339,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol) {
GE_CHECK_NOTNULL(graph);
for (const auto &node : graph->GetAllNodes()) {
for (auto &node : graph->GetAllNodes()) {
// in_data_anchor
if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) {
GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str());
@@ -1427,16 +1396,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node,
return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol);
}

const std::string &type = node->GetType();
std::string type = node->GetType();
if ((type == MERGE) || (type == STREAMMERGE)) {
return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol);
}

for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn);
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
const std::string &symbol = cur_node_info.ToString();
std::string symbol = cur_node_info.ToString();
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str());
symbol_to_anchors[symbol] = {cur_node_info};
anchor_to_symbol[symbol] = symbol;
@@ -1463,7 +1432,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol) {
GE_CHECK_NOTNULL(node);
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut);
if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) {
continue;
@@ -1477,7 +1446,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node,
return GRAPH_FAILED;
}
} else {
const std::string &symbol = cur_node_info.ToString();
std::string symbol = cur_node_info.ToString();
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str());
symbol_to_anchors.emplace(std::make_pair(symbol, std::list<NodeIndexIO>{cur_node_info}));
anchor_to_symbol.emplace(std::make_pair(symbol, symbol));
@@ -1537,7 +1506,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
GE_CHECK_NOTNULL(node);
std::vector<NodeIndexIO> exist_node_infos;
std::vector<NodeIndexIO> cur_node_infos;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
std::string next_name;
@@ -1560,10 +1529,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,

size_t anchor_nums = 0;
NodeIndexIO max_node_index_io(nullptr, 0, kOut);
for (const auto &temp_node_info : exist_node_infos) {
for (auto &temp_node_info : exist_node_infos) {
auto iter1 = anchor_to_symbol.find(temp_node_info.ToString());
if (iter1 != anchor_to_symbol.end()) {
const std::string &temp_symbol = iter1->second;
std::string temp_symbol = iter1->second;
auto iter2 = symbol_to_anchors.find(temp_symbol);
if (iter2 != symbol_to_anchors.end()) {
if (iter2->second.size() > anchor_nums) {
@@ -1575,7 +1544,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
}

std::string symbol;
for (const auto &temp_node_info : exist_node_infos) {
for (auto &temp_node_info : exist_node_infos) {
if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) !=
GRAPH_SUCCESS) ||
symbol.empty()) {
@@ -1587,7 +1556,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,

auto iter = symbol_to_anchors.find(symbol);
if (iter != symbol_to_anchors.end()) {
for (const auto &temp_node_info : cur_node_infos) {
for (auto &temp_node_info : cur_node_infos) {
GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str());
iter->second.emplace_back(temp_node_info);
anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol));
@@ -1615,7 +1584,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node,

OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);

@@ -1658,8 +1627,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node,
graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) {
const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()];
const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()];
std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()];
std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()];
if (symbol1 == symbol2) {
symbol = symbol1;
GELOGI("no need to union.");
@@ -1715,7 +1684,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const
return GRAPH_FAILED;
}

const std::string &symbol = iter1->second;
std::string symbol = iter1->second;
auto iter2 = symbol_to_anchors.find(symbol);
if (iter2 == symbol_to_anchors.end()) {
GE_LOGE("symbol %s not found.", symbol.c_str());
@@ -1743,7 +1712,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t

// pass-through op
NodePtr node = out_data_anchor->GetOwnerNode();
const std::string &type = node->GetType();
std::string type = node->GetType();
const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE};
if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) {
reuse_in_index = output_index;
@@ -1786,7 +1755,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t
uint32_t reuse_input_index = 0;
if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) {
reuse_in_index = static_cast<int32_t>(reuse_input_index);
GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index,
GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index,
reuse_in_index);
return true;
}
@@ -2328,7 +2297,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &
return;
}

std::string name = node->GetName() + "_RetVal_" + std::to_string(index);
std::string name = node->GetName() + "_RetVal";
OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP));
if (ret_val_desc == nullptr) {
error_code = GRAPH_FAILED;


+ 18
- 153
src/common/graph/utils/node_utils.cc View File

@@ -296,18 +296,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
return GRAPH_FAILED;
}
for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
if (!is_unknown_graph) {
output_tensor->SetOriginShape(output_tensor->GetShape());
output_tensor->SetOriginDataType(output_tensor->GetDataType());
}
GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
output_tensor.SetOriginShape(output_tensor.GetShape());
output_tensor.SetOriginDataType(output_tensor.GetDataType());
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(),
TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(),
TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str());
(void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
@@ -319,17 +316,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer
continue;
}
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
output_tensor->GetDataType(), output_tensor->GetOriginDataType());
peer_input_desc->SetShape(output_tensor->GetShape());
peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
peer_input_desc->SetDataType(output_tensor->GetDataType());
peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(),
output_tensor.GetDataType(), output_tensor.GetOriginDataType());
peer_input_desc->SetShape(output_tensor.GetShape());
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape());
peer_input_desc->SetDataType(output_tensor.GetDataType());
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType());
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)output_tensor->GetShapeRange(shape_range);
(void)output_tensor.GetShapeRange(shape_range);
peer_input_desc->SetShapeRange(shape_range);
ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size()));
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
@@ -404,13 +401,10 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
auto desc = node.GetOpDesc();
GE_CHECK_NOTNULL(desc);
// check self
is_unknow = OpShapeIsUnknown(desc);
if (is_unknow) {
return GRAPH_SUCCESS;
}

auto sub_graph_names = desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
is_unknow = OpShapeIsUnknown(desc);
return GRAPH_SUCCESS;
} else {
auto owner_graph = node.GetOwnerComputeGraph();
@@ -561,53 +555,6 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
return peer_out_anchor->GetOwnerNode();
}

///
/// @brief Check is varying_input for while node
/// @param [in] node: Data node for subgraph
/// @return bool
///
bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
if (node == nullptr) {
return false;
}
if (node->GetType() != DATA) {
return false; // not input_node for subgraph
}

const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
if (parent_node == nullptr) {
return false; // root graph
}

if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
return false; // not input_node for while subgraph
}

uint32_t index_i = 0;
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
return false;
}
bool varying_flag = true;
for (const auto &item : node->GetOutDataNodesAndAnchors()) {
if (item.first->GetType() != NETOUTPUT) {
continue;
}
OpDescPtr op_desc = item.first->GetOpDesc();
uint32_t index_o = 0;
if ((op_desc == nullptr) ||
!AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
continue; // input for while-cond subgraph
}
if (index_i != index_o) {
continue; // varying input for while-body subgraph
}
varying_flag = false;
break;
}
return varying_flag;
}

///
/// @brief Get subgraph input is constant.
/// @param [in] node
@@ -690,86 +637,4 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {

return GRAPH_SUCCESS;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
vector<NodePtr> in_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
return in_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
int parent_index = 0;
if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
(void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
if (parent_index == index) {
in_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
}
return in_data_node_vec;
}
///
/// @brief Get subgraph input data node by index.
/// @param [in] node
/// @return Node
///
vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
vector<NodePtr> out_data_node_vec;
auto op_desc = node.GetOpDesc();
GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
auto subgraph_names = op_desc->GetSubgraphInstanceNames();
if (subgraph_names.empty()) {
GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
return out_data_node_vec;
}
auto compute_graph = node.GetOwnerComputeGraph();
for (const std::string &instance_name : subgraph_names) {
auto subgraph = compute_graph->GetSubgraph(instance_name);
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
out_data_node_vec.emplace_back(node_in_subgraph);
}
}
}
return out_data_node_vec;
}

NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) {
if (node.GetInDataAnchor(index) == nullptr) {
return nullptr;
}
if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
return nullptr;
}
return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
}

vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) {
vector<NodePtr> out_data_nodes;
auto out_data_anchor = node.GetOutDataAnchor(index);
if (out_data_anchor == nullptr) {
return out_data_nodes;
}
for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
if (peer_in_anchor == nullptr) {
continue;
}
if (peer_in_anchor->GetOwnerNode() == nullptr) {
continue;
}
out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
}
return out_data_nodes;
}
} // namespace ge

+ 20
- 46
src/common/graph/utils/op_desc_utils.cc View File

@@ -197,33 +197,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::
continue;
}
auto in_node = out_anchor->GetOwnerNode();
while (true) {
if (in_node == nullptr) {
break;
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
} else if (in_node->GetType() == DATA) {
const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null");

const NodePtr &parent_node = graph->GetParentNode();
if (parent_node == nullptr) {
continue; // Root graph.
}
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) {
ret.push_back(in_node);
break;
} else if (in_node->GetType() == DATA) {
if (NodeUtils::IsWhileVaryingInput(in_node)) {
break;
}
in_node = NodeUtils::GetParentInput(in_node);
} else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) {
bool is_constant = false;
(void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant);
if (!is_constant) {
break;
}
// Enter node has and only has one input
if (in_node->GetInDataNodes().size() != 1) {
GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(),
in_node->GetInDataNodes().size());
break;
}
in_node = in_node->GetInDataNodes().at(0);
} else {
break;

if (kWhileOpTypes.count(parent_node->GetType()) > 0) {
continue; // Subgraph of While cond or body.
}

NodePtr input_node = NodeUtils::GetParentInput(in_node);
if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) {
ret.push_back(input_node);
}
}
}
@@ -444,27 +435,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) {
vector<GeTensorPtr> ret;
auto op_desc = node.GetOpDesc();
GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!");
// Place holder operator, try to get the weight from parent node
// when parent node is const operator
if (node.GetType() == PLACEHOLDER) {
std::string parent_op;
(void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op);
// This if judgment is necessary because the current subgraph optimization is multithreaded
// and the parent node of the PLD operation should be a stable type, such as const
if (parent_op == CONSTANT || parent_op == CONSTANTOP) {
NodePtr parent_node = nullptr;
parent_node = op_desc->TryGetExtAttr("parentNode", parent_node);
if (parent_node != nullptr) {
op_desc = parent_node->GetOpDesc();
GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str());
}
}
}
GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!");
// Const operator, take the weight directly
if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(op_desc);
if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) {
auto weight = MutableWeights(node.GetOpDesc());
if (weight == nullptr) {
GELOGI("const op has no weight, op name:%s", node.GetName().c_str());
return ret;


+ 2
- 6
src/common/graph/utils/tensor_utils.cc View File

@@ -19,7 +19,6 @@

#include "debug/ge_log.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_tensor.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"
@@ -106,10 +105,7 @@ static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_
element_cnt = 1;
for (int64_t dim : dims) {
if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19013", {"function", "var1", "var2"},
{"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)});
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim);
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim);
return GRAPH_FAILED;
}
element_cnt *= dim;
@@ -277,6 +273,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
case FORMAT_FRACTAL_Z:
graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
break;
case FORMAT_NC1HWC0_C04:
case FORMAT_FRACTAL_NZ:
case FORMAT_FRACTAL_ZZ:
case FORMAT_NDHWC:
@@ -288,7 +285,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
case FORMAT_NDC1HWC0:
case FORMAT_FRACTAL_Z_C04:
case FORMAT_FRACTAL_ZN_LSTM:
case FORMAT_NC1HWC0_C04:
graph_status = CalcElementCntByDims(dims, element_cnt);
break;
default:


+ 1
- 2
src/common/graph/utils/type_utils.cc View File

@@ -147,8 +147,7 @@ static const std::map<std::string, Format> kStringToFormatMap = {
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM},
{"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G},
{"FORMAT_RESERVED", FORMAT_RESERVED},
{"ALL", FORMAT_ALL},
{"NULL", FORMAT_NULL}};
{"ALL", FORMAT_ALL}};

static const std::map<DataType, std::string> kDataTypeToStringMap = {
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.


Loading…
Cancel
Save