| @@ -27,6 +27,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const int kReshapeDataIndex = 0; | const int kReshapeDataIndex = 0; | ||||
| const char* const ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; | |||||
| enum OpHashValue { | enum OpHashValue { | ||||
| kReshapeType = 0, | kReshapeType = 0, | ||||
| kReformatType = 1, | kReformatType = 1, | ||||
| @@ -45,6 +46,25 @@ Status ReshapeRemovePass::Run(NodePtr &node) { | |||||
| int key = kToBeDeleteOp.find(node->GetType()) == kToBeDeleteOp.end() ? kOpNoDelete : kToBeDeleteOp[node->GetType()]; | int key = kToBeDeleteOp.find(node->GetType()) == kToBeDeleteOp.end() ? kOpNoDelete : kToBeDeleteOp[node->GetType()]; | ||||
| switch (key) { | switch (key) { | ||||
| case kReshapeType: { | case kReshapeType: { | ||||
| bool is_in_unknown_shape_graph = false; | |||||
| bool forced_unknown = false; | |||||
| for (const auto &node : node->GetOwnerComputeGraph()->GetDirectNode()) { | |||||
| GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_in_unknown_shape_graph), | |||||
| "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); | |||||
| if (is_in_unknown_shape_graph) { | |||||
| break; | |||||
| } | |||||
| if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) { | |||||
| GELOGD("node %s was marked as unknown shape.", node->GetName().c_str()); | |||||
| is_in_unknown_shape_graph = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (is_in_unknown_shape_graph) { | |||||
| GELOGI("op:%s is in unknown shape graph, can not be deleted.", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| bool is_shape_unknown = false; | bool is_shape_unknown = false; | ||||
| if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { | if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { | ||||
| if (is_shape_unknown) { | if (is_shape_unknown) { | ||||