diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc index 33aa407d..9dc7d6a1 100644 --- a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc +++ b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_builder.cc @@ -22,9 +22,11 @@ #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/compute_graph.h" #include "ge_local_engine/ops_kernel_store/op/op_factory.h" #include "ge_local_engine/common/constant/constant.h" #include "register/ops_kernel_builder_registry.h" +#include "framework/common/debug/log.h" namespace ge { namespace ge_local { @@ -34,6 +36,7 @@ namespace { const char *const kConstantOpType = "Constant"; const char *const kConstantOpAttrName = "value"; const char *const kDataOpType = "Data"; +const char *const ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; } // namespace GeLocalOpsKernelBuilder::~GeLocalOpsKernelBuilder() { @@ -161,13 +164,24 @@ Status GeLocalOpsKernelBuilder::CalcConstantStrMemSize(const OpDescPtr &op_desc, } Status GeLocalOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, std::vector &tasks) { - bool is_shape_unknown = false; - if (NodeUtils::GetNodeUnknownShapeStatus(node, is_shape_unknown) == GRAPH_SUCCESS) { - if (is_shape_unknown) { - GELOGI("op:%s is unknown shape, does not need to generate task", - node.GetName().c_str()); - return SUCCESS; + bool is_in_unknown_subgraph = false; + bool forced_unknown = false; + for (const auto &node : node.GetOwnerComputeGraph()->GetDirectNode()) { + GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_in_unknown_subgraph), + "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); + if (is_in_unknown_subgraph) { + break; } + if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) { + GELOGD("node %s was marked as unknown shape.", node->GetName().c_str()); + is_in_unknown_subgraph = true; + break; + } + } + if (is_in_unknown_subgraph) { + GELOGI("op:%s is in unknown shape subgraph, does not need to generate task", + node.GetName().c_str()); + return SUCCESS; } string name = node.GetName(); string type = node.GetType(); diff --git a/ge/graph/passes/reshape_remove_pass.cc b/ge/graph/passes/reshape_remove_pass.cc index 10937cf1..79ab88b5 100755 --- a/ge/graph/passes/reshape_remove_pass.cc +++ b/ge/graph/passes/reshape_remove_pass.cc @@ -27,6 +27,7 @@ namespace ge { namespace { const int kReshapeDataIndex = 0; +const char* const ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; enum OpHashValue { kReshapeType = 0, kReformatType = 1, @@ -45,6 +46,25 @@ Status ReshapeRemovePass::Run(NodePtr &node) { int key = kToBeDeleteOp.find(node->GetType()) == kToBeDeleteOp.end() ? kOpNoDelete : kToBeDeleteOp[node->GetType()]; switch (key) { case kReshapeType: { + bool is_in_unknown_shape_graph = false; + bool forced_unknown = false; + for (const auto &node : node->GetOwnerComputeGraph()->GetDirectNode()) { + GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_in_unknown_shape_graph), + "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); + if (is_in_unknown_shape_graph) { + break; + } + if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) { + GELOGD("node %s was marked as unknown shape.", node->GetName().c_str()); + is_in_unknown_shape_graph = true; + break; + } + } + if (is_in_unknown_shape_graph) { + GELOGI("op:%s is in unknown shape graph, can not be deleted.", node->GetName().c_str()); + return SUCCESS; + } + bool is_shape_unknown = false; if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { if (is_shape_unknown) {