@@ -71,6 +71,5 @@ target_link_libraries(graph PRIVATE | |||
${PROTOBUF_LIBRARY} | |||
${c_sec} | |||
${slog} | |||
${error_manager} | |||
rt | |||
dl) |
@@ -106,15 +106,6 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::share | |||
return Vistor<NodePtr>(shared_from_this(), all_nodes); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes( | |||
bool is_unknown_shape) const { | |||
if (is_unknown_shape) { | |||
return GetDirectNode(); | |||
} else { | |||
return GetAllNodes(); | |||
} | |||
} | |||
size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | |||
@@ -506,10 +497,6 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<Compute | |||
if (name != subgraph->GetName()) { | |||
GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); | |||
} | |||
if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { | |||
GE_LOGE("The subgraph %s existed", name.c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
sub_graph_.push_back(subgraph); | |||
names_to_subgraph_[name] = subgraph; | |||
return GRAPH_SUCCESS; | |||
@@ -34,16 +34,12 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | |||
GE_REGISTER_OPTYPE(SWITCH, "Switch"); | |||
GE_REGISTER_OPTYPE(MERGE, "Merge"); | |||
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||
GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||
GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | |||
GE_REGISTER_OPTYPE(CONSTANT, "Const"); | |||
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | |||
GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | |||
GE_REGISTER_OPTYPE(INITDATA, "InitData"); | |||
GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); | |||
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | |||
GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | |||
@@ -41,9 +41,11 @@ using namespace ge; | |||
using namespace std; | |||
namespace ge { | |||
namespace { | |||
const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
const string kIsGraphInferred = "_is_graph_inferred"; | |||
RefRelations reflection_builder; | |||
static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
static bool net_format_is_nd = true; | |||
static Format g_user_set_format = FORMAT_ND; | |||
static bool is_first_infer = true; | |||
static RefRelations reflection_builder; | |||
} // namespace | |||
graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||
@@ -70,49 +72,9 @@ graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &re | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { | |||
// 5 meas dim num | |||
if (node_ptr->GetType() != "BiasAdd") { | |||
return GRAPH_SUCCESS; | |||
} | |||
std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; | |||
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { | |||
auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); | |||
GE_CHECK_NOTNULL(in_desc); | |||
if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||
continue; | |||
} | |||
auto format = in_desc->GetOriginFormat(); | |||
auto key = TypeUtils::FormatToSerialString(format); | |||
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||
in_desc->SetOriginFormat(fixed_format); | |||
in_desc->SetFormat(fixed_format); | |||
GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, | |||
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||
TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||
} | |||
for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { | |||
auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); | |||
GE_CHECK_NOTNULL(out_desc); | |||
if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||
continue; | |||
} | |||
auto format = out_desc->GetOriginFormat(); | |||
auto key = TypeUtils::FormatToSerialString(format); | |||
auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||
out_desc->SetOriginFormat(fixed_format); | |||
out_desc->SetFormat(fixed_format); | |||
GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, | |||
node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||
TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||
GE_CHECK_NOTNULL(graph); | |||
graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
GE_CHECK_NOTNULL(op_desc); | |||
if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { | |||
if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | |||
ConstGeTensorPtr tensor_value; | |||
if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { | |||
GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); | |||
@@ -133,7 +95,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
} | |||
anchor_points.clear(); | |||
// Get all anchor point nodes and switch nodes | |||
for (auto &node_ptr : graph->GetAllNodes()) { | |||
for (const auto &node_ptr : graph->GetAllNodes()) { | |||
if (node_ptr == nullptr) { | |||
return GRAPH_FAILED; | |||
} | |||
@@ -141,7 +103,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
if (op_desc == nullptr) { | |||
return GRAPH_FAILED; | |||
} | |||
graphStatus status = RefreshConstantOutProcess(graph, op_desc); | |||
graphStatus status = RefreshConstantOutProcess(op_desc); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); | |||
return GRAPH_FAILED; | |||
@@ -173,16 +135,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
if (!node_is_all_nd) { | |||
continue; | |||
} | |||
// special process for biasAdd op | |||
// In tensorflow, biasAdd's format is alwayse NHWC even though set the arg | |||
// "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism | |||
// so here do special process | |||
status = BiasAddFormatFixProcess(node_ptr); | |||
if (status != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); | |||
return GRAPH_FAILED; | |||
} | |||
GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); | |||
anchor_points.push_back(node_ptr); | |||
} | |||
@@ -392,11 +344,14 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor | |||
} | |||
} | |||
graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||
ge::Format data_format, | |||
void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; } | |||
graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { | |||
GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), | |||
bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | |||
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||
if (!need_process) { | |||
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||
TypeUtils::FormatToSerialString(data_format).c_str()); | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -455,6 +410,8 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
std::vector<ge::NodePtr> anchor_points; | |||
std::vector<ge::NodePtr> data_nodes; | |||
// global net format | |||
net_format_is_nd = true; | |||
g_user_set_format = FORMAT_ND; | |||
if (graph == nullptr) { | |||
GELOGE(GRAPH_FAILED, "input graph is null"); | |||
@@ -491,15 +448,10 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
/// format for these data nodes. | |||
/// Notice: ignore 5D formats | |||
auto data_format = graph->GetDataFormat(); | |||
status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); | |||
(void)AttrUtils::SetBool(graph, kIsGraphInferred, true); | |||
status = DataNodeFormatProcess(data_nodes, data_format, node_status); | |||
// Set infer flag to false | |||
SetInferOrigineFormatFlag(false); | |||
return status; | |||
} | |||
bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { | |||
bool is_graph_inferred = false; | |||
return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); | |||
} | |||
} // namespace ge |
@@ -30,9 +30,10 @@ namespace ge { | |||
class FormatRefiner { | |||
public: | |||
static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); | |||
static void SetInferOrigineFormatFlag(bool is_first = true); | |||
private: | |||
static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||
static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc); | |||
static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||
std::vector<ge::NodePtr> &data_nodes, | |||
std::unordered_map<ge::NodePtr, bool> &node_status); | |||
@@ -42,9 +43,8 @@ class FormatRefiner { | |||
std::unordered_map<ge::NodePtr, bool> &node_status); | |||
static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||
std::unordered_map<ge::NodePtr, bool> &node_status); | |||
static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||
ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status); | |||
static bool IsGraphInferred(const ComputeGraphPtr &graph); | |||
static graphStatus DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
std::unordered_map<ge::NodePtr, bool> &node_status); | |||
}; | |||
} // namespace ge | |||
#endif // COMMON_GRAPH_FORMAT_REFINER_H_ |
@@ -725,10 +725,6 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | |||
const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||
const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; | |||
const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; | |||
// Public attribute | |||
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | |||
@@ -938,7 +934,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||
const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
// lx fusion | |||
// l1 fusion and other fusion in future | |||
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
@@ -952,17 +948,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 | |||
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; | |||
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; | |||
const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; | |||
const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||
const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; | |||
const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; | |||
const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; | |||
// Op debug attrs | |||
const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | |||
const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; | |||
// Atomic addr clean attrs | |||
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
@@ -1027,11 +1015,4 @@ const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; | |||
// used for allreduce tailing optimization | |||
const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; | |||
const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; | |||
// dynamic shape attr | |||
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; | |||
const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; | |||
// for fusion op plugin | |||
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | |||
} // namespace ge |
@@ -220,7 +220,6 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | |||
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | |||
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | |||
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; | |||
const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; | |||
GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | |||
@@ -568,16 +567,6 @@ DataType GeTensorDesc::GetOriginDataType() const { | |||
return TypeUtils::SerialStringToDataType(origin_data_type_str); | |||
} | |||
std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const { | |||
vector<uint32_t> ref_port_index; | |||
(void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); | |||
return ref_port_index; | |||
} | |||
void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) { | |||
(void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); | |||
} | |||
graphStatus GeTensorDesc::IsValid() const { | |||
auto dtype = this->GetDataType(); | |||
auto format = this->GetFormat(); | |||
@@ -210,7 +210,7 @@ class GraphImpl { | |||
graphStatus FindOpByName(const string &name, ge::Operator &op) const { | |||
auto it = op_list_.find(name); | |||
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); | |||
GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str()); | |||
op = it->second; | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -77,7 +77,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||
libc_sec \ | |||
libprotobuf \ | |||
libslog \ | |||
liberror_manager \ | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
@@ -95,35 +94,10 @@ LOCAL_CPPFLAGS += -fexceptions | |||
LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
LOCAL_SRC_FILES := \ | |||
../../out/graph/lib64/stub/graph.cc \ | |||
../../out/graph/lib64/stub/operator.cc \ | |||
../../out/graph/lib64/stub/tensor.cc \ | |||
../../out/graph/lib64/stub/operator_factory.cc \ | |||
LOCAL_SHARED_LIBRARIES := | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
LOCAL_MULTILIB := 64 | |||
LOCAL_PROPRIETARY_MODULE := true | |||
include $(BUILD_HOST_SHARED_LIBRARY) | |||
#compiler for host | |||
include $(CLEAR_VARS) | |||
LOCAL_MODULE := fwk_stub/libgraph | |||
LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||
LOCAL_CPPFLAGS += -fexceptions | |||
LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
LOCAL_SRC_FILES := \ | |||
../../out/graph/lib64/stub/attr_value.cc \ | |||
../../out/graph/lib64/stub/graph.cc \ | |||
../../out/graph/lib64/stub/operator.cc \ | |||
../../out/graph/lib64/stub/operator_factory.cc \ | |||
../../out/graph/lib64/stub/tensor.cc \ | |||
../../out/atc/lib64/stub/graph.cc \ | |||
../../out/atc/lib64/stub/operator.cc \ | |||
../../out/atc/lib64/stub/tensor.cc \ | |||
../../out/atc/lib64/stub/operator_factory.cc \ | |||
LOCAL_SHARED_LIBRARIES := | |||
@@ -148,7 +122,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||
libc_sec \ | |||
libprotobuf \ | |||
libslog \ | |||
liberror_manager \ | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
@@ -169,38 +142,10 @@ LOCAL_CFLAGS += -O2 | |||
LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
LOCAL_SRC_FILES := \ | |||
../../out/graph/lib64/stub/graph.cc \ | |||
../../out/graph/lib64/stub/operator.cc \ | |||
../../out/graph/lib64/stub/tensor.cc \ | |||
../../out/graph/lib64/stub/operator_factory.cc \ | |||
LOCAL_SHARED_LIBRARIES := | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
ifeq ($(device_os),android) | |||
LOCAL_LDFLAGS := -ldl | |||
endif | |||
LOCAL_MULTILIB := 64 | |||
LOCAL_PROPRIETARY_MODULE := true | |||
include $(BUILD_SHARED_LIBRARY) | |||
#compiler for device | |||
include $(CLEAR_VARS) | |||
LOCAL_MODULE := fwk_stub/libgraph | |||
LOCAL_CFLAGS += -O2 | |||
LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
LOCAL_SRC_FILES := \ | |||
../../out/graph/lib64/stub/attr_value.cc \ | |||
../../out/graph/lib64/stub/graph.cc \ | |||
../../out/graph/lib64/stub/operator.cc \ | |||
../../out/graph/lib64/stub/operator_factory.cc \ | |||
../../out/graph/lib64/stub/tensor.cc \ | |||
../../out/atc/lib64/stub/graph.cc \ | |||
../../out/atc/lib64/stub/operator.cc \ | |||
../../out/atc/lib64/stub/tensor.cc \ | |||
../../out/atc/lib64/stub/operator_factory.cc \ | |||
LOCAL_SHARED_LIBRARIES := | |||
@@ -229,7 +174,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||
libc_sec \ | |||
libprotobuf \ | |||
libslog \ | |||
liberror_manager \ | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
@@ -255,7 +199,6 @@ LOCAL_STATIC_LIBRARIES := \ | |||
LOCAL_SHARED_LIBRARIES := \ | |||
libc_sec \ | |||
libslog \ | |||
liberror_manager \ | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
@@ -279,7 +222,6 @@ LOCAL_STATIC_LIBRARIES := \ | |||
LOCAL_SHARED_LIBRARIES := \ | |||
libc_sec \ | |||
libslog \ | |||
liberror_manager \ | |||
LOCAL_LDFLAGS := -lrt -ldl | |||
@@ -88,8 +88,10 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||
} | |||
bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||
GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); | |||
GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); | |||
if (op_desc == nullptr || op_def_proto == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Input Para Invalid"); | |||
return false; | |||
} | |||
if (op_desc->op_def_.GetProtoMsg() != nullptr) { | |||
*op_def_proto = *op_desc->op_def_.GetProtoMsg(); | |||
// Delete unnecessary attr | |||
@@ -128,17 +130,16 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
op_def_proto->add_subgraph_name(name); | |||
} | |||
if (!op_desc->output_name_idx_.empty()) { | |||
proto::AttrDef key; | |||
proto::AttrDef value; | |||
for (auto &item : op_desc->output_name_idx_) { | |||
key.mutable_list()->add_s(item.first); | |||
value.mutable_list()->add_i(item.second); | |||
} | |||
auto op_desc_attr = op_def_proto->mutable_attr(); | |||
op_desc_attr->insert({"_output_name_key", key}); | |||
op_desc_attr->insert({"_output_name_value", value}); | |||
proto::AttrDef key; | |||
proto::AttrDef value; | |||
for (auto &item : op_desc->output_name_idx_) { | |||
key.mutable_list()->add_s(item.first); | |||
value.mutable_list()->add_i(item.second); | |||
} | |||
auto op_desc_attr = op_def_proto->mutable_attr(); | |||
op_desc_attr->insert({"_output_name_key", key}); | |||
op_desc_attr->insert({"_output_name_value", value}); | |||
} | |||
return true; | |||
} | |||
@@ -26,7 +26,6 @@ | |||
#include "utils/ge_ir_utils.h" | |||
#include "utils/node_utils.h" | |||
#include "utils/op_desc_utils.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
using std::string; | |||
using std::vector; | |||
@@ -155,7 +154,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons | |||
const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||
const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||
if (peer_node == nullptr || r_peer_node == nullptr) { | |||
GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||
GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||
this->GetName().c_str(), i, j); | |||
return false; | |||
} | |||
@@ -435,11 +434,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { | |||
if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||
ErrorManager::GetInstance().ATCReportErrMessage( | |||
"E19019", {"opname", "index", "anchorname", "optype"}, | |||
{GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); | |||
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, | |||
GetType().c_str()); | |||
GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(), | |||
GetName().c_str()); | |||
return nullptr; | |||
} else { | |||
return in_data_anchors_[idx]; | |||
@@ -449,10 +445,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { | |||
// Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ | |||
if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||
ErrorManager::GetInstance().ATCReportErrMessage( | |||
"E19019", {"opname", "index", "anchorname", "optype"}, | |||
{GetName().c_str(), std::to_string(idx), "in_anchor", GetType().c_str()}); | |||
GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); | |||
GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str()); | |||
return nullptr; | |||
} else { | |||
// Return control anchor | |||
@@ -468,15 +461,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { | |||
// Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ | |||
if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, | |||
{ | |||
GetName().c_str(), | |||
std::to_string(idx), | |||
"out_anchor", | |||
GetType().c_str(), | |||
}); | |||
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, | |||
GetType().c_str()); | |||
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(), | |||
GetName().c_str()); | |||
return nullptr; | |||
} else { | |||
// Return control anchor | |||
@@ -491,11 +477,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { | |||
if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||
ErrorManager::GetInstance().ATCReportErrMessage( | |||
"E19019", {"opname", "index", "anchorname", "optype"}, | |||
{GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); | |||
GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, | |||
GetType().c_str()); | |||
GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(), | |||
GetName().c_str()); | |||
return nullptr; | |||
} else { | |||
return out_data_anchors_[idx]; | |||
@@ -750,15 +733,11 @@ graphStatus Node::Verify() const { | |||
GELOGW("in anchor ptr is null"); | |||
continue; | |||
} | |||
bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || | |||
op_->GetType() == const_type || op_->GetType() == variable_type || | |||
op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; | |||
if (!valid_anchor) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"name", "index"}, | |||
{GetName(), std::to_string(in_anchor_ptr->GetIdx())}); | |||
GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||
return GRAPH_FAILED; | |||
} | |||
GE_CHK_BOOL_RET_STATUS( | |||
op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || | |||
op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || | |||
in_anchor_ptr->GetPeerAnchors().size() > 0, | |||
GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||
} | |||
string frameworkop_type = "FrameworkOp"; | |||
@@ -19,7 +19,6 @@ | |||
#include "debug/ge_util.h" | |||
#include "external/graph/operator.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
#include "graph/ge_attr_value.h" | |||
#include "graph/ge_tensor.h" | |||
#include "graph/operator_factory_impl.h" | |||
@@ -471,25 +470,6 @@ GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | |||
return *(inputs_desc_[it->second].get()); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { | |||
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); | |||
if (inputs_desc_[index] == nullptr) { | |||
return nullptr; | |||
} | |||
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); | |||
return inputs_desc_[index]; | |||
} | |||
GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { | |||
auto input_name_idx = GetAllInputName(); | |||
auto it = input_name_idx.find(name); | |||
if (it == input_name_idx.end()) { | |||
GELOGW("Failed to get [%s] input desc", name.c_str()); | |||
return nullptr; | |||
} | |||
return MutableInputDesc(it->second); | |||
} | |||
GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | |||
auto input_name_idx = GetAllInputName(); | |||
vector<string> names; | |||
@@ -516,6 +496,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { | |||
GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); | |||
if (inputs_desc_[index] == nullptr) { | |||
return nullptr; | |||
} | |||
GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); | |||
return inputs_desc_[index]; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const { | |||
vector<GeTensorDesc> temp{}; | |||
for (const auto &it : inputs_desc_) { | |||
@@ -620,15 +609,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||
return outputs_desc_[index]; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { | |||
auto it = output_name_idx_.find(name); | |||
if (it == output_name_idx_.end()) { | |||
GELOGW("Failed to get [%s] output desc", name.c_str()); | |||
return nullptr; | |||
} | |||
return MutableOutputDesc(it->second); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||
return static_cast<uint32_t>(outputs_desc_.size()); | |||
} | |||
@@ -902,22 +882,15 @@ graphStatus OpDesc::CommonVerify() const { | |||
// Checking shape of all inputs | |||
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||
for (int64_t dim : ishape) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( | |||
"E19014", {"opname", "value", "reason"}, | |||
{GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); | |||
return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), | |||
iname.c_str()); | |||
GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | |||
iname.c_str()); | |||
} | |||
} | |||
// Check all attributes defined | |||
const auto &all_attributes = GetAllAttrs(); | |||
for (const auto &name : GetAllAttrNames()) { | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
all_attributes.find(name) == all_attributes.end(), | |||
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, | |||
{GetName(), "attribute " + name, "is empty"}); | |||
return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); | |||
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | |||
"operator attribute %s is empty.", name.c_str()); | |||
} | |||
return GRAPH_SUCCESS; | |||
@@ -36,8 +36,6 @@ | |||
#include "graph/op_desc.h" | |||
#include "graph/runtime_inference_context.h" | |||
#include "graph/usr_types.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "utils/graph_utils.h" | |||
#include "utils/op_desc_utils.h" | |||
#include "utils/tensor_adapter.h" | |||
@@ -59,7 +57,8 @@ using std::vector; | |||
namespace ge { | |||
class OpIO { | |||
public: | |||
OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} | |||
explicit OpIO(const string &name, int index, const OperatorImplPtr &owner) | |||
: name_(name), index_(index), owner_(owner) {} | |||
~OpIO() = default; | |||
@@ -547,46 +546,56 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) { | |||
} | |||
graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { | |||
GE_CHECK_NOTNULL(operator_impl_); | |||
auto node_ptr = operator_impl_->GetNode(); | |||
if (node_ptr != nullptr) { | |||
if (operator_impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | |||
return GRAPH_FAILED; | |||
} | |||
ge::ConstNodePtr node_ptr = operator_impl_->GetNode(); | |||
if (node_ptr) { | |||
// For inner compute graph | |||
auto op_desc = node_ptr->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
if (op_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "op_desc is nullptr."); | |||
return GRAPH_FAILED; | |||
} | |||
auto index = op_desc->GetInputIndexByName(dst_name); | |||
auto in_data_anchor = node_ptr->GetInDataAnchor(index); | |||
GE_CHECK_NOTNULL(in_data_anchor); | |||
if (in_data_anchor == nullptr) { | |||
GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr."); | |||
return GRAPH_FAILED; | |||
} | |||
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(out_data_anchor); | |||
auto peer_node = out_data_anchor->GetOwnerNode(); | |||
GE_CHECK_NOTNULL(peer_node); | |||
auto peer_op_desc = peer_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(peer_op_desc); | |||
auto peer_op_type = peer_op_desc->GetType(); | |||
if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { | |||
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node); | |||
GE_CHECK_NOTNULL(const_op_impl); | |||
Operator const_op(std::move(const_op_impl)); | |||
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); | |||
} else if (peer_op_type == DATA) { | |||
auto parent_node = NodeUtils::GetParentInput(peer_node); | |||
while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { | |||
parent_node = NodeUtils::GetParentInput(parent_node); | |||
} | |||
if ((parent_node != nullptr) && | |||
((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { | |||
auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node); | |||
GE_CHECK_NOTNULL(const_op_impl); | |||
Operator const_op(std::move(const_op_impl)); | |||
return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); | |||
if (out_data_anchor == nullptr) { | |||
GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr."); | |||
return GRAPH_FAILED; | |||
} | |||
std::shared_ptr<Node> peer_node_ptr = out_data_anchor->GetOwnerNode(); | |||
if (peer_node_ptr == nullptr) { | |||
GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr."); | |||
return GRAPH_FAILED; | |||
} | |||
ge::OperatorImplPtr operator_impl_ptr = nullptr; | |||
operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(peer_node_ptr); | |||
if (operator_impl_ptr == nullptr) { | |||
GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); | |||
return GRAPH_FAILED; | |||
} | |||
Operator const_op(std::move(operator_impl_ptr)); | |||
if (peer_node_ptr->GetOpDesc() != nullptr) { | |||
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | |||
if (op_descType == CONSTANTOP) { | |||
return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
} else if (op_descType == CONSTANT) { | |||
return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
} | |||
} | |||
// Try get from runtime inference context | |||
auto session_id = std::to_string(GetContext().SessionId()); | |||
RuntimeInferenceContext *runtime_infer_ctx = nullptr; | |||
if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { | |||
GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); | |||
auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||
auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||
if (ret == GRAPH_SUCCESS) { | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -595,8 +604,6 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||
// For outer graph | |||
return GetInputConstDataOut(dst_name, data); | |||
} | |||
auto op_name = operator_impl_->GetName(); | |||
GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { | |||
@@ -85,8 +85,6 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||
// Returns the trace id recorded in this context.
uint64_t GEContext::TraceId() { return trace_id_; }
// Stores the session id used by subsequent context lookups.
void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; }
// Stores the device id for the current context.
void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; }
} // namespace ge |
@@ -242,10 +242,6 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r | |||
int sub_graph_idx = 0; | |||
for (const auto &name : sub_graph_names) { | |||
auto sub_graph = root_graph.GetSubgraph(name); | |||
if (sub_graph == nullptr) { | |||
GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); | |||
continue; | |||
} | |||
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||
auto sub_graph_node_type = sub_graph_node->GetType(); | |||
@@ -37,115 +37,6 @@ | |||
namespace ge { | |||
namespace { | |||
const uint32_t kWhileBodySubGraphIdx = 1; | |||
// Marks every input and output tensor desc of the nodes in the while-body
// subgraph (subgraph index kWhileBodySubGraphIdx) as unknown-dim-num, so the
// shapes are re-inferred on a later pass. Fails when the body graph is absent.
graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
  GELOGD("Enter reverse brush while body subgraph process!");
  auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
  if (sub_graph_body == nullptr) {
    GELOGE(GRAPH_FAILED, "Get while body graph failed!");
    return GRAPH_FAILED;
  }
  for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
    // Nodes without any data inputs keep their shapes untouched.
    if (node_sub->GetInDataNodes().size() == 0) {
      continue;
    }
    for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
      auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i);
      (void)input_desc->SetUnknownDimNumShape();
    }
    for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
      auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i);
      (void)output_desc->SetUnknownDimNumShape();
    }
  }
  return GRAPH_SUCCESS;
}
// Refreshes the parent (branch-style) node's output descs from the output
// tensors collected across its subgraphs: per output index, the first tensor
// is taken as reference; a rank mismatch among subgraphs makes the output
// unknown-rank, a per-dim mismatch sets that dim to UNKNOWN_DIM, and mixed
// output dtypes are an error.
// Fix: the GELOGD format strings used %d for the size_t indices i/j, which is
// undefined behavior in printf-style formatting; they now use %zu.
graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
                                      std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class branch op process");
  // check sub_graph shape.If not same ,do unknown shape process
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i : %zu, j: %zu ,shape size: %lu, ref_out_tensor_shape size: %lu",
               node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }
  return GRAPH_SUCCESS;
}
// Refreshes a While node's output descs: for each output index i, compares the
// collected Data tensors (ref_data_tensors[i]) with the single collected
// output tensor (ref_out_tensors[i]). Input/output counts must match and each
// output must have exactly one candidate. A shape mismatch marks the output
// unknown-dim-num and triggers a reverse refresh of the while-body subgraph.
graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(),
           ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!");
      return GRAPH_FAILED;
    }
  }
  bool is_need_reverse_brush = false;
  // check input and output
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto tmp_shape = ref_out_tensor.MutableShape();
    // ref_i's data and output tensor shape should be same
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims() != tmp_shape.GetDims()) {
        // Loop-carried shape is not stable across iterations: give up on a
        // fixed shape for this output.
        ref_out_tensor.SetUnknownDimNumShape();
        is_need_reverse_brush = true;
        break;
      }
    }
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }
  // reverse refresh while body shape
  if (is_need_reverse_brush) {
    return ReverseBrushWhileBodySubGraph(node);
  }
  return GRAPH_SUCCESS;
}
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
auto op_desc = node->GetOpDesc(); | |||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
@@ -207,37 +98,6 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
// Scans the direct nodes of sub_graph (in reverse order) to locate its
// NETOUTPUT node and to collect every DATA node's output desc into
// ref_data_tensors, indexed by the DATA node's ATTR_NAME_PARENT_NODE_INDEX.
// Fails when a DATA node has no op desc, lacks the parent index attribute, or
// the index is outside the parent node's input range.
graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
                                         const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      // NOTE(review): %zu assumes GetAllInDataAnchorsSize() is size_t; other
      // call sites log it with %u — confirm the specifier matches the type.
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(),
               ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
auto op_desc = node->GetOpDesc(); | |||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
@@ -245,10 +105,7 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
return GRAPH_SUCCESS; | |||
} | |||
std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize()); | |||
std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize()); | |||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
for (const auto &name : sub_graph_names) { | |||
if (name.empty()) { | |||
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
@@ -260,9 +117,13 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
return GRAPH_FAILED; | |||
} | |||
NodePtr netoutput = nullptr; | |||
auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); | |||
if (ret != GRAPH_SUCCESS) { | |||
return ret; | |||
auto sub_nodes = sub_graph->GetDirectNode(); | |||
for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
auto sub_node = sub_nodes.at(i - 1); | |||
if (sub_node->GetType() == NETOUTPUT) { | |||
netoutput = sub_node; | |||
break; | |||
} | |||
} | |||
if (netoutput == nullptr) { | |||
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||
@@ -289,17 +150,19 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
continue; | |||
} | |||
GELOGI("Parent node index of edge desc is %d", ref_i); | |||
if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { | |||
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
if (output_desc == nullptr) { | |||
GE_LOGE( | |||
"The ref index(%d) on the input %d of netoutput %s on the sub graph %s " | |||
"parent node %s are incompatible, outputs num %u", | |||
ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
node->GetAllOutDataAnchorsSize()); | |||
return GRAPH_FAILED; | |||
} | |||
ref_out_tensors[ref_i].emplace_back(*edge_desc); | |||
op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); | |||
} | |||
} | |||
if (node->GetType() == WHILE) { | |||
return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); | |||
} | |||
return UpdateParentNodeForBranch(node, ref_out_tensors); | |||
return GRAPH_SUCCESS; | |||
} | |||
} // namespace | |||
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | |||
@@ -307,9 +170,6 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
GELOGE(GRAPH_FAILED, "node is null"); | |||
return; | |||
} | |||
if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||
return; | |||
} | |||
ge::OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | |||
std::string str; | |||
@@ -0,0 +1,6 @@ | |||
# Paths used to generate the graph stub library sources from public headers.
inc_path := $(shell pwd)/inc/external/
out_path := $(shell pwd)/out/atc/lib64/stub/
stub_path := $(shell pwd)/common/graph/stub/
# Make sure the output directory exists before the generator runs.
mkdir_stub := $(shell mkdir -p $(out_path))
# Run gen_stubapi.py to produce stub .cc files for the headers in inc_path.
graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path))
@@ -0,0 +1,573 @@ | |||
import os | |||
import re | |||
import sys | |||
import logging | |||
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', | |||
level=logging.INFO) | |||
""" | |||
this attr is used for symbol table visible | |||
""" | |||
GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' | |||
""" | |||
generate stub func body by return type | |||
""" | |||
RETURN_STATEMENTS = { | |||
'graphStatus': ' return GRAPH_SUCCESS;', | |||
'Status': ' return SUCCESS;', | |||
'Graph': ' return Graph();', | |||
'Graph&': ' return *this;', | |||
'Format': ' return Format();', | |||
'Format&': ' return *this;', | |||
'Shape': ' return Shape();', | |||
'Shape&': ' return *this;', | |||
'TensorDesc': ' return TensorDesc();', | |||
'TensorDesc&': ' return *this;', | |||
'Tensor': ' return Tensor();', | |||
'Tensor&': ' return *this;', | |||
'Operator': ' return Operator();', | |||
'Operator&': ' return *this;', | |||
'Ptr': ' return nullptr;', | |||
'std::string': ' return "";', | |||
'std::string&': ' return "";', | |||
'string': ' return "";', | |||
'int': ' return 0;', | |||
'DataType': ' return DT_FLOAT;', | |||
'InferenceContextPtr': ' return nullptr;', | |||
'SubgraphBuilder': ' return nullptr;', | |||
'OperatorImplPtr': ' return nullptr;', | |||
'OutHandler': ' return nullptr;', | |||
'std::vector<std::string>': ' return {};', | |||
'std::vector<int64_t>': ' return {};', | |||
'std::map': ' return {};', | |||
'uint32_t': ' return 0;', | |||
'int64_t': ' return 0;', | |||
'uint64_t': ' return 0;', | |||
'size_t': ' return 0;', | |||
'float': ' return 0.0f;', | |||
'bool': ' return false;', | |||
} | |||
""" | |||
max code len per line in hua_wei software programming specifications | |||
""" | |||
max_code_len_per_line = 100 | |||
""" | |||
white_list_for_debug, include_dir_key_words is to | |||
determines which header files to generate cc files from | |||
when DEBUG on | |||
""" | |||
white_list_for_debug = ["operator.h", "tensor.h", | |||
"graph.h", "operator_factory.h", | |||
"ge_ir_build.h"] | |||
include_dir_key_words = ["ge", "graph"] | |||
DEBUG = True | |||
def need_generate_func(func_line):
    """Decide whether a stub body must be generated for a declaration.

    Defaulted/deleted special members, typedefs, and using-aliases need no
    stub implementation.

    :param func_line: one (possibly joined) declaration line
    :return: True when a stub should be emitted for this declaration
    """
    stripped = func_line.strip()
    if stripped.endswith(("default", "delete")):
        return False
    if stripped.startswith(("typedef", "using")):
        return False
    return True
def file_endswith_white_list_suffix(file):
    """Tell whether this header should be converted to a stub .cc file.

    With DEBUG enabled, only headers named in white_list_for_debug are
    accepted; otherwise every header passes.

    :param file: header file path
    :return: True when the file should be processed
    """
    if not DEBUG:
        return True
    return any(file.endswith(suffix) for suffix in white_list_for_debug)
""" | |||
belows are patterns used for analyse .h file | |||
""" | |||
# pattern function | |||
pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after | |||
([a-zA-Z~_] # void int likely | |||
.* | |||
[)] #we find ) | |||
(?!.*{) # we do not want the case int abc() const { return 1;} | |||
.*) | |||
(;.*) #we want to find ; and after for we will replace these later | |||
\n$ | |||
""", re.VERBOSE | re.MULTILINE | re.DOTALL) | |||
# pattern comment | |||
pattern_comment = re.compile(r'^\s*//') | |||
pattern_comment_2_start = re.compile(r'^\s*/[*]') | |||
pattern_comment_2_end = re.compile(r'[*]/\s*$') | |||
# pattern define | |||
pattern_define = re.compile(r'^\s*#define') | |||
pattern_define_return = re.compile(r'\\\s*$') | |||
# blank line | |||
pattern_blank_line = re.compile(r'^\s*$') | |||
# virtual,explicit,friend,static | |||
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') | |||
# lead space | |||
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') | |||
# functions will have patterns such as func ( or func( | |||
# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist | |||
# format like :"operator = ()" | |||
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') | |||
# template | |||
pattern_template = re.compile(r'^\s*template') | |||
pattern_template_end = re.compile(r'>\s*$') | |||
# namespace | |||
pattern_namespace = re.compile(r'namespace.*{') | |||
# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with | |||
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR) | |||
# {} | |||
pattern_start = re.compile('{') | |||
pattern_end = re.compile('}') | |||
line_index = 0 | |||
class H2CC(object):
    """Converts a single C++ header file into a stub .cc implementation.

    The header is scanned line by line while a small stack tracks whether the
    scanner is inside a namespace, a class, or an ordinary brace scope. Every
    function declaration that needs a stub (see need_generate_func) is
    re-emitted into the output file with a dummy body generated from its
    return type (see RETURN_STATEMENTS).

    Fix over the original: handle_class_member_func referenced the bare name
    `stack_template` instead of `self.stack_template`, raising NameError for
    any templated class.
    """

    def __init__(self, input_file, output_file, shared_includes_content):
        """
        :param input_file: path of the header to convert
        :param output_file: path of the .cc stub file to create
        :param shared_includes_content: '#include' lines copied to the top of the output
        """
        self.input_file = input_file
        self.output_file = output_file
        self.shared_includes_content = shared_includes_content
        self.line_index = 0
        self.input_fd = open(self.input_file, 'r')
        self.input_content = self.input_fd.readlines()
        self.output_fd = open(self.output_file, 'w')

        # The state may be normal_now(in the middle of {}),class_now,namespace_now
        self.stack = []
        self.stack_class = []
        self.stack_template = []
        # record funcs generated by h2cc func
        self.func_list_exist = []

    def __del__(self):
        self.input_fd.close()
        self.output_fd.close()
        del self.stack
        del self.stack_class
        del self.stack_template
        del self.func_list_exist

    def just_skip(self):
        """Advance self.line_index past blank lines, comments and #defines."""
        # skip blank line or comment
        if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
                self.input_content[self.line_index]):  # /n or comment using //
            self.line_index += 1
        if pattern_comment_2_start.search(self.input_content[self.line_index]):  # comment using /*
            while not pattern_comment_2_end.search(self.input_content[self.line_index]):  # */
                self.line_index += 1
            self.line_index += 1
        # skip define
        if pattern_define.search(self.input_content[self.line_index]):
            while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
                    self.input_content[self.line_index]):
                self.line_index += 1
            self.line_index += 1

    def write_inc_content(self):
        """Copy the shared '#include' prologue into the output file."""
        for shared_include_content in self.shared_includes_content:
            self.output_fd.write(shared_include_content)

    def h2cc(self):
        """Drive the conversion: scan the header and emit stub functions.

        :return: None; the stubs are written to self.output_file
        """
        logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
        # write inc content
        self.write_inc_content()
        # core processing cycle, process the input .h file by line
        while self.line_index < len(self.input_content):
            # handle comment and blank line
            self.just_skip()
            # match namespace
            self.handle_namespace()
            # match template
            template_string = self.handle_template()
            # match class
            line = self.input_content[self.line_index]
            match_class = pattern_class.search(line)
            match_start = pattern_start.search(line)
            handle_class_result = self.handle_class(template_string, line, match_start, match_class)
            if handle_class_result == "continue":
                continue
            # match "}"
            handle_stack_result = self.handle_stack(match_start)
            if handle_stack_result == "continue":
                continue
            # handle func
            handle_func1_result, line, start_i = self.handle_func1(line)
            if handle_func1_result == "continue":
                continue
            # here means func is found
            # delete key word
            line = pattern_keyword.sub('', line)
            logging.info("line[%s]", line)
            # Class member function
            # if friend we will not add class name
            friend_match = re.search('friend ', line)
            if len(self.stack_class) > 0 and not friend_match:
                line, func_name = self.handle_class_member_func(line, template_string)
            # Normal functions
            else:
                line, func_name = self.handle_normal_func(line, template_string)
            need_generate = need_generate_func(line)
            # func body
            line += self.implement_function(line)
            # comment
            line = self.gen_comment(start_i) + line
            # write to out file
            self.write_func_content(line, func_name, need_generate)
            # next loop
            self.line_index += 1

        logging.info('Added %s functions', len(self.func_list_exist))
        logging.info('Successfully converted,please see ' + self.output_file)

    def handle_func1(self, line):
        """Detect a declaration on `line`, joining continuation lines.

        Joins lines until the closing ')' is seen and strips the trailing ';'.

        :param line: current source line
        :return: ("continue"|"pass", joined declaration, start line index)
        """
        find1 = re.search('[(]', line)
        if not find1:
            self.line_index += 1
            return "continue", line, None
        find2 = re.search('[)]', line)
        start_i = self.line_index
        space_match = pattern_leading_space.search(line)
        # deal with
        # int abc(int a,
        #        int b)
        if find1 and (not find2):
            self.line_index += 1
            line2 = self.input_content[self.line_index]
            if space_match:
                line2 = re.sub('^' + space_match.group(1), '', line2)
            line += line2
            while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
                self.line_index += 1
                line2 = self.input_content[self.line_index]
                line2 = re.sub('^' + space_match.group(1), '', line2)
                line += line2

        match_start = pattern_start.search(self.input_content[self.line_index])
        match_end = pattern_end.search(self.input_content[self.line_index])
        if match_start:  # like ) { or ) {} int the last line
            if not match_end:
                self.stack.append('normal_now')
            self.line_index += 1
            return "continue", line, start_i
        logging.info("line[%s]", line)
        # ' int abc();'->'int abc()'
        (line, match) = pattern_func.subn(r'\2\n', line)
        logging.info("line[%s]", line)
        # deal with case:
        # 'int \n abc(int a, int b)'
        if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
            line = self.input_content[start_i - 1] + line
            line = line.lstrip()
        if not match:
            self.line_index += 1
            return "continue", line, start_i
        return "pass", line, start_i

    def handle_stack(self, match_start):
        """Maintain the scope stack on '{' / '}' and skip nested bodies.

        :param match_start: regex match for '{' on the current line, or None
        :return: "continue" when the line was consumed, else "pass"
        """
        line = self.input_content[self.line_index]
        match_end = pattern_end.search(line)
        if match_start:
            self.stack.append('normal_now')
        if match_end:
            top_status = self.stack.pop()
            if top_status == 'namespace_now':
                self.output_fd.write(line + '\n')
            elif top_status == 'class_now':
                self.stack_class.pop()
                self.stack_template.pop()
        if match_start or match_end:
            self.line_index += 1
            return "continue"

        if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
            self.line_index += 1
            return "continue"
        return "pass"

    def handle_class(self, template_string, line, match_start, match_class):
        """Enter a class scope, recording its (possibly templated) name.

        :param template_string: template<> prefix preceding the class, if any
        :param line: current source line
        :param match_start: regex match for '{' on the line, or None
        :param match_class: regex match of pattern_class, or None
        :return: "continue" when a class was entered, else "pass"
        """
        if match_class:  # we face a class
            self.stack_template.append(template_string)
            self.stack.append('class_now')
            class_name = match_class.group(3)

            # class template specializations: class A<u,Node<u> >
            if '<' in class_name:
                k = line.index('<')
                fit = 1
                for ii in range(k + 1, len(line)):
                    if line[ii] == '<':
                        fit += 1
                    if line[ii] == '>':
                        fit -= 1
                    if fit == 0:
                        break
                class_name += line[k + 1:ii + 1]
            logging.info('class_name[%s]', class_name)
            self.stack_class.append(class_name)
            while not match_start:
                self.line_index += 1
                line = self.input_content[self.line_index]
                match_start = pattern_start.search(line)
            self.line_index += 1
            return "continue"
        return "pass"

    def handle_template(self):
        """Collect a (possibly multi-line) template<> prefix, if present.

        :return: the joined template string, or '' when there is none
        """
        line = self.input_content[self.line_index]
        match_template = pattern_template.search(line)
        template_string = ''
        if match_template:
            match_template_end = pattern_template_end.search(line)
            template_string = line
            while not match_template_end:
                self.line_index += 1
                line = self.input_content[self.line_index]
                template_string += line
                match_template_end = pattern_template_end.search(line)
            self.line_index += 1
        return template_string

    def handle_namespace(self):
        """Echo a namespace opening into the output and push namespace scope."""
        line = self.input_content[self.line_index]
        match_namespace = pattern_namespace.search(line)
        if match_namespace:  # we face namespace
            self.output_fd.write(line + '\n')
            self.stack.append('namespace_now')
            self.line_index += 1

    def handle_normal_func(self, line, template_string):
        """Prepare a free-function declaration for stub emission.

        Default arguments are stripped and the template prefix re-attached.

        :param line: declaration line
        :param template_string: template<> prefix, if any
        :return: (declaration line, function name)
        """
        template_line = ''
        self.stack_template.append(template_string)
        if self.stack_template[-1] != '':
            template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
            # change '< class T = a, class U = A(3)>' to '<class T, class U>'
            template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
            template_line = re.sub(r'\s*=.*,', ',', template_line)
            template_line = re.sub(r'\s*=.*', '', template_line)
        line = re.sub(r'\s*=.*,', ',', line)
        line = re.sub(r'\s*=.*\)', ')', line)
        line = template_line + line
        self.stack_template.pop()
        func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
        logging.info("line[%s]", line)
        logging.info("func_name[%s]", func_name)
        return line, func_name

    def handle_class_member_func(self, line, template_string):
        """Qualify a member function with its class name (and template args).

        :param line: declaration line
        :param template_string: template<> prefix of the method, if any
        :return: (qualified declaration line, function name)
        """
        template_line = ''
        x = ''
        if template_string != '':
            template_string = re.sub(r'\s*template', 'template', template_string)
            template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
            template_string = re.sub(r'\s*=.*,', ',', template_string)
            template_string = re.sub(r'\s*=.*', '', template_string)
        # bugfix: the original read the bare global name `stack_template`
        # here, which does not exist -> NameError for templated classes.
        if self.stack_template[-1] != '':
            if not (re.search(r'<\s*>', self.stack_template[-1])):
                template_line = re.sub(r'^\s*template', 'template', self.stack_template[-1])
                if not (re.search(r'<.*>', self.stack_class[-1])):
                    # for x we get like template<class T, typename U> -> <T,U>
                    x = re.sub(r'template\s*<', '<', template_line)  # remove template -> <class T, typename U>
                    x = re.sub(r'\n', '', x)
                    x = re.sub(r'\s*=.*,', ',', x)
                    x = re.sub(r'\s*=.*\>', '>', x)
                    x = x.rstrip()  # remove \n
                    x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
                               x)  # remove class,typename -> <T, U>
                    x = re.sub(r'<\s+', '<', x)
                    x = re.sub(r'\s+>', '>', x)
                    x = re.sub(r'\s+,', ',', x)
                    x = re.sub(r',\s+', ', ', x)
        line = re.sub(r'\s*=\s+0', '', line)
        line = re.sub(r'\s*=\s+.*,', ',', line)
        line = re.sub(r'\s*=\s+.*\)', ')', line)
        logging.info("x[%s]\nline[%s]", x, line)
        # if the function is long, void ABC::foo()
        # breaks into two lines void ABC::\n foo()
        temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
        if len(temp_line) > max_code_len_per_line:
            line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
        else:
            line = temp_line
        logging.info("line[%s]", line)
        # add template as the above if there is one
        template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
        template_line = re.sub(r'\s*=.*,', ',', template_line)
        template_line = re.sub(r'\s*=.*', '', template_line)
        line = template_line + template_string + line
        func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
        logging.info("line[%s]", line)
        logging.info("func_name[%s]", func_name)
        return line, func_name

    def write_func_content(self, content, func_name, need_generate):
        """Write a stub once per unique function name, when generation is needed."""
        if not (func_name in self.func_list_exist) and need_generate:
            self.output_fd.write(content)
            self.func_list_exist.append(func_name)
            logging.info('add func:[%s]', func_name)

    def gen_comment(self, start_i):
        """Copy the comment block sitting directly above the declaration.

        :param start_i: line index where the declaration starts
        :return: the comment text to prepend (may be '')
        """
        comment_line = ''
        # Function comments are on top of function declarations, copy them over
        k = start_i - 1  # one line before this func start
        if pattern_template.search(self.input_content[k]):
            k -= 1
        if pattern_comment_2_end.search(self.input_content[k]):
            comment_line = self.input_content[k].lstrip()
            while not pattern_comment_2_start.search(self.input_content[k]):
                k -= 1
                comment_line = self.input_content[k].lstrip() + comment_line
        else:
            for j in range(k, 0, -1):
                c_line = self.input_content[j]
                if pattern_comment.search(c_line):
                    c_line = re.sub(r'\s*//', '//', c_line)
                    comment_line = c_line + comment_line
                else:
                    break
        return comment_line

    @staticmethod
    def implement_function(func):
        """Build a dummy body for `func` from its return type.

        :param func: the (joined) declaration line
        :return: the '{ ... }' body text, using RETURN_STATEMENTS
        """
        function_def = ''
        function_def += '{\n'

        all_items = func.split()
        start = 0
        return_type = all_items[start]
        if return_type == "const":
            start += 1
            return_type = all_items[start]
        if return_type.startswith(('std::map', 'std::set', 'std::vector')):
            return_type = "std::map"
        if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
            return_type = "Ptr"
        if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
            return_type += "&"
        if RETURN_STATEMENTS.__contains__(return_type):
            function_def += RETURN_STATEMENTS[return_type]
        else:
            logging.warning("Unhandled return type[%s]", return_type)

        function_def += '\n'
        function_def += '}\n'
        function_def += '\n'
        return function_def
def collect_header_files(path):
    """Recursively collect C/C++ header files under ``path``.

    :param path: root directory to scan; expected to contain (and, as
        normalized by gen_code(), end with) '/' so include paths can be
        made relative to it.
    :return: tuple (header_files, shared_includes_content) where
        header_files holds the full path of every ``.h`` file found and
        shared_includes_content holds the matching ``#include "..."`` lines
        with the path made relative to the scan root.
    """
    header_files = []
    shared_includes_content = []
    for root, dirs, files in os.walk(path):
        # Sort both dirs and files so the traversal (and the generated
        # include list) is deterministic across platforms and filesystems.
        dirs.sort()
        files.sort()
        for file in files:
            # Skip VCS artifacts (any file name containing "git").
            if "git" in file:
                continue
            if not file.endswith('.h'):
                continue
            file_path = os.path.join(root, file)
            file_path = file_path.replace('\\', '/')
            header_files.append(file_path)
            # Strip everything up to the scan root's final '/' to get a
            # root-relative include path.
            include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
            shared_includes_content.append(include_str)
    return header_files, shared_includes_content
def generate_stub_file(inc_dir, out_cc_dir):
    """Generate one stub .cc file for every whitelisted header under inc_dir.

    :param inc_dir: directory scanned for headers.
    :param out_cc_dir: directory (with trailing '/') the stubs are written to.
    :return: None
    """
    target_header_files, shared_includes_content = collect_header_files(inc_dir)
    for header_file in target_header_files:
        if not file_endswith_white_list_suffix(header_file):
            continue
        # Turn the ".h" suffix into ".cc". The original pattern '.h*$' left
        # the dot unescaped (any char followed by a run of 'h'); the escaped
        # form states the intent explicitly.
        cc_file = re.sub(r'\.h$', '.cc', header_file)
        # Output name is the header's basename with the new suffix.
        h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
        h_2_cc.h2cc()
def gen_code(inc_dir, out_cc_dir):
    """Drive stub generation for every configured include sub-directory.

    :param inc_dir: base include directory; a trailing '/' is appended when missing.
    :param out_cc_dir: output directory for stubs, normalized the same way.
    :return: None
    """
    normalized_inc = inc_dir if inc_dir.endswith('/') else inc_dir + '/'
    normalized_out = out_cc_dir if out_cc_dir.endswith('/') else out_cc_dir + '/'
    for key_word in include_dir_key_words:
        generate_stub_file(normalized_inc + key_word, normalized_out)
if __name__ == '__main__':
    # Usage: <script> <include_dir> <output_cc_dir>
    # Validate argv explicitly instead of failing with a bare IndexError.
    if len(sys.argv) != 3:
        logging.error("usage: %s <inc_dir> <out_cc_dir>", sys.argv[0])
        sys.exit(1)
    inc_dir = sys.argv[1]
    out_cc_dir = sys.argv[2]
    gen_code(inc_dir, out_cc_dir)
@@ -1,18 +1,18 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
*/ | |||
#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||
#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||
@@ -38,7 +38,6 @@ | |||
#include "utils/ge_ir_utils.h" | |||
#include "utils/node_utils.h" | |||
#include "debug/ge_op_types.h" | |||
#include "external/ge/ge_api_types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
@@ -411,8 +410,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra | |||
/// @return graphStatus | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { | |||
GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { | |||
GE_CHECK_NOTNULL(src); | |||
GE_CHECK_NOTNULL(insert_node); | |||
@@ -571,7 +570,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||
static int max_dumpfile_num = 0; | |||
if (max_dumpfile_num == 0) { | |||
string opt = "0"; | |||
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||
(void)GetContext().GetOption("ge.maxDumpFileNum", opt); | |||
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||
} | |||
if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { | |||
@@ -671,7 +670,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText | |||
if (maxDumpFileSize == 0) { | |||
string opt = "0"; | |||
// Can not check return value | |||
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); | |||
(void)GetContext().GetOption("ge.maxDumpFileSize", opt); | |||
maxDumpFileSize = atol(opt.c_str()); | |||
} | |||
if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { | |||
@@ -741,7 +740,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||
static int max_dumpfile_num = 0; | |||
if (max_dumpfile_num == 0) { | |||
string opt = "0"; | |||
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||
(void)GetContext().GetOption("ge.maxDumpFileNum", opt); | |||
max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||
} | |||
if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { | |||
@@ -921,7 +920,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector<int> &io_map, In | |||
InNodesToOut GetFullConnectIONodes(const NodePtr &node) { | |||
InNodesToOut in_nodes_to_out; | |||
if (node == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Node is nullptr"); | |||
GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str()); | |||
return in_nodes_to_out; | |||
} | |||
auto in_nodes_list = node->GetInNodes(); | |||
@@ -1309,36 +1308,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// Copy all in-data edges from `src_node` to `dst_node`. | |||
/// @param src_node | |||
/// @param dst_node | |||
/// @return | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, | |||
NodePtr &dst_node) { | |||
if ((src_node == nullptr) || (dst_node == nullptr)) { | |||
GELOGE(GRAPH_FAILED, "Parameter is nullptr"); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
auto src_data_in_nodes = src_node->GetInDataNodes(); | |||
if (src_data_in_nodes.empty()) { | |||
return GRAPH_SUCCESS; | |||
} | |||
for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { | |||
auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); | |||
auto ret = | |||
GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", | |||
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), | |||
src_node->GetName().c_str(), dst_node->GetName().c_str()); | |||
return ret; | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, | |||
const NodePtr &node) { | |||
if (graph->AddInputNode(node) == nullptr) { | |||
@@ -1370,7 +1339,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(graph); | |||
for (const auto &node : graph->GetAllNodes()) { | |||
for (auto &node : graph->GetAllNodes()) { | |||
// in_data_anchor | |||
if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); | |||
@@ -1427,16 +1396,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, | |||
return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); | |||
} | |||
const std::string &type = node->GetType(); | |||
std::string type = node->GetType(); | |||
if ((type == MERGE) || (type == STREAMMERGE)) { | |||
return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); | |||
} | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if (peer_out_anchor == nullptr) { | |||
const std::string &symbol = cur_node_info.ToString(); | |||
std::string symbol = cur_node_info.ToString(); | |||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
symbol_to_anchors[symbol] = {cur_node_info}; | |||
anchor_to_symbol[symbol] = symbol; | |||
@@ -1463,7 +1432,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol) { | |||
GE_CHECK_NOTNULL(node); | |||
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); | |||
if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { | |||
continue; | |||
@@ -1477,7 +1446,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
return GRAPH_FAILED; | |||
} | |||
} else { | |||
const std::string &symbol = cur_node_info.ToString(); | |||
std::string symbol = cur_node_info.ToString(); | |||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
symbol_to_anchors.emplace(std::make_pair(symbol, std::list<NodeIndexIO>{cur_node_info})); | |||
anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); | |||
@@ -1537,7 +1506,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
GE_CHECK_NOTNULL(node); | |||
std::vector<NodeIndexIO> exist_node_infos; | |||
std::vector<NodeIndexIO> cur_node_infos; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if (peer_out_anchor == nullptr) { | |||
std::string next_name; | |||
@@ -1560,10 +1529,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
size_t anchor_nums = 0; | |||
NodeIndexIO max_node_index_io(nullptr, 0, kOut); | |||
for (const auto &temp_node_info : exist_node_infos) { | |||
for (auto &temp_node_info : exist_node_infos) { | |||
auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); | |||
if (iter1 != anchor_to_symbol.end()) { | |||
const std::string &temp_symbol = iter1->second; | |||
std::string temp_symbol = iter1->second; | |||
auto iter2 = symbol_to_anchors.find(temp_symbol); | |||
if (iter2 != symbol_to_anchors.end()) { | |||
if (iter2->second.size() > anchor_nums) { | |||
@@ -1575,7 +1544,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
} | |||
std::string symbol; | |||
for (const auto &temp_node_info : exist_node_infos) { | |||
for (auto &temp_node_info : exist_node_infos) { | |||
if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
GRAPH_SUCCESS) || | |||
symbol.empty()) { | |||
@@ -1587,7 +1556,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
auto iter = symbol_to_anchors.find(symbol); | |||
if (iter != symbol_to_anchors.end()) { | |||
for (const auto &temp_node_info : cur_node_infos) { | |||
for (auto &temp_node_info : cur_node_infos) { | |||
GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); | |||
iter->second.emplace_back(temp_node_info); | |||
anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); | |||
@@ -1615,7 +1584,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(peer_out_anchor); | |||
@@ -1658,8 +1627,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) { | |||
const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
if (symbol1 == symbol2) { | |||
symbol = symbol1; | |||
GELOGI("no need to union."); | |||
@@ -1715,7 +1684,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const | |||
return GRAPH_FAILED; | |||
} | |||
const std::string &symbol = iter1->second; | |||
std::string symbol = iter1->second; | |||
auto iter2 = symbol_to_anchors.find(symbol); | |||
if (iter2 == symbol_to_anchors.end()) { | |||
GE_LOGE("symbol %s not found.", symbol.c_str()); | |||
@@ -1743,7 +1712,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||
// pass-through op | |||
NodePtr node = out_data_anchor->GetOwnerNode(); | |||
const std::string &type = node->GetType(); | |||
std::string type = node->GetType(); | |||
const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; | |||
if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { | |||
reuse_in_index = output_index; | |||
@@ -1786,7 +1755,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||
uint32_t reuse_input_index = 0; | |||
if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { | |||
reuse_in_index = static_cast<int32_t>(reuse_input_index); | |||
GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
reuse_in_index); | |||
return true; | |||
} | |||
@@ -2328,7 +2297,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & | |||
return; | |||
} | |||
std::string name = node->GetName() + "_RetVal_" + std::to_string(index); | |||
std::string name = node->GetName() + "_RetVal"; | |||
OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); | |||
if (ret_val_desc == nullptr) { | |||
error_code = GRAPH_FAILED; | |||
@@ -296,18 +296,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
return GRAPH_FAILED; | |||
} | |||
for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { | |||
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||
bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
if (!is_unknown_graph) { | |||
output_tensor->SetOriginShape(output_tensor->GetShape()); | |||
output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||
} | |||
GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx()); | |||
ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
output_tensor.SetOriginShape(output_tensor.GetShape()); | |||
output_tensor.SetOriginDataType(output_tensor.GetDataType()); | |||
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | |||
node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), | |||
TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||
TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||
node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(), | |||
TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(), | |||
TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str()); | |||
(void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); | |||
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | |||
@@ -319,17 +316,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
continue; | |||
} | |||
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), | |||
output_tensor->GetDataType(), output_tensor->GetOriginDataType()); | |||
peer_input_desc->SetShape(output_tensor->GetShape()); | |||
peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); | |||
peer_input_desc->SetDataType(output_tensor->GetDataType()); | |||
peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); | |||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), | |||
output_tensor.GetDataType(), output_tensor.GetOriginDataType()); | |||
peer_input_desc->SetShape(output_tensor.GetShape()); | |||
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | |||
peer_input_desc->SetDataType(output_tensor.GetDataType()); | |||
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | |||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
(void)output_tensor->GetShapeRange(shape_range); | |||
(void)output_tensor.GetShapeRange(shape_range); | |||
peer_input_desc->SetShapeRange(shape_range); | |||
ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||
static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||
@@ -404,13 +401,10 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||
auto desc = node.GetOpDesc(); | |||
GE_CHECK_NOTNULL(desc); | |||
// check self | |||
is_unknow = OpShapeIsUnknown(desc); | |||
if (is_unknow) { | |||
return GRAPH_SUCCESS; | |||
} | |||
auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||
if (sub_graph_names.empty()) { | |||
is_unknow = OpShapeIsUnknown(desc); | |||
return GRAPH_SUCCESS; | |||
} else { | |||
auto owner_graph = node.GetOwnerComputeGraph(); | |||
@@ -561,53 +555,6 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { | |||
return peer_out_anchor->GetOwnerNode(); | |||
} | |||
/// | |||
/// @brief Check is varying_input for while node | |||
/// @param [in] node: Data node for subgraph | |||
/// @return bool | |||
/// | |||
bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { | |||
if (node == nullptr) { | |||
return false; | |||
} | |||
if (node->GetType() != DATA) { | |||
return false; // not input_node for subgraph | |||
} | |||
const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||
if (parent_node == nullptr) { | |||
return false; // root graph | |||
} | |||
if (kWhileOpTypes.count(parent_node->GetType()) == 0) { | |||
return false; // not input_node for while subgraph | |||
} | |||
uint32_t index_i = 0; | |||
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { | |||
GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); | |||
return false; | |||
} | |||
bool varying_flag = true; | |||
for (const auto &item : node->GetOutDataNodesAndAnchors()) { | |||
if (item.first->GetType() != NETOUTPUT) { | |||
continue; | |||
} | |||
OpDescPtr op_desc = item.first->GetOpDesc(); | |||
uint32_t index_o = 0; | |||
if ((op_desc == nullptr) || | |||
!AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { | |||
continue; // input for while-cond subgraph | |||
} | |||
if (index_i != index_o) { | |||
continue; // varying input for while-body subgraph | |||
} | |||
varying_flag = false; | |||
break; | |||
} | |||
return varying_flag; | |||
} | |||
/// | |||
/// @brief Get subgraph input is constant. | |||
/// @param [in] node | |||
@@ -690,86 +637,4 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// @brief Get subgraph input data node by index. | |||
/// @param [in] node | |||
/// @return Node | |||
/// | |||
vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { | |||
vector<NodePtr> in_data_node_vec; | |||
auto op_desc = node.GetOpDesc(); | |||
GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); | |||
auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
if (subgraph_names.empty()) { | |||
GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); | |||
return in_data_node_vec; | |||
} | |||
auto compute_graph = node.GetOwnerComputeGraph(); | |||
for (const std::string &instance_name : subgraph_names) { | |||
auto subgraph = compute_graph->GetSubgraph(instance_name); | |||
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||
int parent_index = 0; | |||
if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { | |||
(void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); | |||
if (parent_index == index) { | |||
in_data_node_vec.emplace_back(node_in_subgraph); | |||
} | |||
} | |||
} | |||
} | |||
return in_data_node_vec; | |||
} | |||
/// | |||
/// @brief Get subgraph input data node by index. | |||
/// @param [in] node | |||
/// @return Node | |||
/// | |||
vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) { | |||
vector<NodePtr> out_data_node_vec; | |||
auto op_desc = node.GetOpDesc(); | |||
GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); | |||
auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
if (subgraph_names.empty()) { | |||
GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); | |||
return out_data_node_vec; | |||
} | |||
auto compute_graph = node.GetOwnerComputeGraph(); | |||
for (const std::string &instance_name : subgraph_names) { | |||
auto subgraph = compute_graph->GetSubgraph(instance_name); | |||
for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||
if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { | |||
out_data_node_vec.emplace_back(node_in_subgraph); | |||
} | |||
} | |||
} | |||
return out_data_node_vec; | |||
} | |||
NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { | |||
if (node.GetInDataAnchor(index) == nullptr) { | |||
return nullptr; | |||
} | |||
if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { | |||
return nullptr; | |||
} | |||
return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); | |||
} | |||
vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { | |||
vector<NodePtr> out_data_nodes; | |||
auto out_data_anchor = node.GetOutDataAnchor(index); | |||
if (out_data_anchor == nullptr) { | |||
return out_data_nodes; | |||
} | |||
for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
if (peer_in_anchor == nullptr) { | |||
continue; | |||
} | |||
if (peer_in_anchor->GetOwnerNode() == nullptr) { | |||
continue; | |||
} | |||
out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); | |||
} | |||
return out_data_nodes; | |||
} | |||
} // namespace ge |
@@ -197,33 +197,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
continue; | |||
} | |||
auto in_node = out_anchor->GetOwnerNode(); | |||
while (true) { | |||
if (in_node == nullptr) { | |||
break; | |||
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
ret.push_back(in_node); | |||
} else if (in_node->GetType() == DATA) { | |||
const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); | |||
GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); | |||
const NodePtr &parent_node = graph->GetParentNode(); | |||
if (parent_node == nullptr) { | |||
continue; // Root graph. | |||
} | |||
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
ret.push_back(in_node); | |||
break; | |||
} else if (in_node->GetType() == DATA) { | |||
if (NodeUtils::IsWhileVaryingInput(in_node)) { | |||
break; | |||
} | |||
in_node = NodeUtils::GetParentInput(in_node); | |||
} else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { | |||
bool is_constant = false; | |||
(void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); | |||
if (!is_constant) { | |||
break; | |||
} | |||
// Enter node has and only has one input | |||
if (in_node->GetInDataNodes().size() != 1) { | |||
GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), | |||
in_node->GetInDataNodes().size()); | |||
break; | |||
} | |||
in_node = in_node->GetInDataNodes().at(0); | |||
} else { | |||
break; | |||
if (kWhileOpTypes.count(parent_node->GetType()) > 0) { | |||
continue; // Subgraph of While cond or body. | |||
} | |||
NodePtr input_node = NodeUtils::GetParentInput(in_node); | |||
if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { | |||
ret.push_back(input_node); | |||
} | |||
} | |||
} | |||
@@ -444,27 +435,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) { | |||
vector<GeTensorPtr> ret; | |||
auto op_desc = node.GetOpDesc(); | |||
GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); | |||
// Place holder operator, try to get the weight from parent node | |||
// when parent node is const operator | |||
if (node.GetType() == PLACEHOLDER) { | |||
std::string parent_op; | |||
(void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); | |||
// This if judgment is necessary because the current subgraph optimization is multithreaded | |||
// and the parent node of the PLD operation should be a stable type, such as const | |||
if (parent_op == CONSTANT || parent_op == CONSTANTOP) { | |||
NodePtr parent_node = nullptr; | |||
parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); | |||
if (parent_node != nullptr) { | |||
op_desc = parent_node->GetOpDesc(); | |||
GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); | |||
} | |||
} | |||
} | |||
GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!"); | |||
// Const operator, take the weight directly | |||
if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { | |||
auto weight = MutableWeights(op_desc); | |||
if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) { | |||
auto weight = MutableWeights(node.GetOpDesc()); | |||
if (weight == nullptr) { | |||
GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); | |||
return ret; | |||
@@ -19,7 +19,6 @@ | |||
#include "debug/ge_log.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
#include "graph/ge_tensor.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/type_utils.h" | |||
@@ -106,10 +105,7 @@ static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_ | |||
element_cnt = 1; | |||
for (int64_t dim : dims) { | |||
if (CheckMultiplyOverflowInt64(element_cnt, dim)) { | |||
ErrorManager::GetInstance().ATCReportErrMessage( | |||
"E19013", {"function", "var1", "var2"}, | |||
{"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); | |||
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); | |||
GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim); | |||
return GRAPH_FAILED; | |||
} | |||
element_cnt *= dim; | |||
@@ -277,6 +273,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
case FORMAT_FRACTAL_Z: | |||
graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); | |||
break; | |||
case FORMAT_NC1HWC0_C04: | |||
case FORMAT_FRACTAL_NZ: | |||
case FORMAT_FRACTAL_ZZ: | |||
case FORMAT_NDHWC: | |||
@@ -288,7 +285,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
case FORMAT_NDC1HWC0: | |||
case FORMAT_FRACTAL_Z_C04: | |||
case FORMAT_FRACTAL_ZN_LSTM: | |||
case FORMAT_NC1HWC0_C04: | |||
graph_status = CalcElementCntByDims(dims, element_cnt); | |||
break; | |||
default: | |||
@@ -147,8 +147,7 @@ static const std::map<std::string, Format> kStringToFormatMap = { | |||
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||
{"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, | |||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
{"ALL", FORMAT_ALL}, | |||
{"NULL", FORMAT_NULL}}; | |||
{"ALL", FORMAT_ALL}}; | |||
static const std::map<DataType, std::string> kDataTypeToStringMap = { | |||
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||