diff --git a/ge/graph/passes/no_use_reshape_remove_pass.cc b/ge/graph/passes/no_use_reshape_remove_pass.cc index 44f520f0..c2b8bdad 100644 --- a/ge/graph/passes/no_use_reshape_remove_pass.cc +++ b/ge/graph/passes/no_use_reshape_remove_pass.cc @@ -82,21 +82,40 @@ Status NoUseReshapeRemovePass::Run(ge::NodePtr &node) { } } if (to_be_deleted) { - GELOGI("NoUseReshapeRemovePass remove useless node:%s", node->GetName().c_str()); - // if shape_input has no any input,which means a single const, it can be unlink from reshape - // op(x) const(shape) - // \ / - // reshape - auto shape_input_anchor = node->GetInDataAnchor(kReshapeShapeIndex); - if (shape_input_anchor != nullptr) { - auto shape_input = shape_input_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(shape_input); - if (shape_input->GetInAllNodes().empty()) { - shape_input_anchor->UnlinkAll(); - } - } + auto ret = TryRemoveConstShapeInput(node); + GE_CHK_STATUS_RET_NOLOG(ret); + GELOGI("NoUseReshapeRemovePass remove useless reshape node:%s", node->GetName().c_str()); return IsolateAndDeleteNode(node, {kReshapeDataIndex}); } return SUCCESS; } + +Status NoUseReshapeRemovePass::TryRemoveConstShapeInput(ge::NodePtr &reshape_node) { + auto shape_input_anchor = reshape_node->GetInDataAnchor(kReshapeShapeIndex); + if (shape_input_anchor == nullptr) { + return SUCCESS; + } + auto shape_input = shape_input_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(shape_input); + if (shape_input->GetType() != CONSTANT && shape_input->GetType() != CONSTANTOP) { + return SUCCESS; + } + // op(x) const(shape) + // \ / + // reshape + // const input can unlink but should copy control_dependency + auto ret = PassUtils::UnlinkNodeWithControlCopy(reshape_node, kReshapeShapeIndex); + if (ret != SUCCESS) { + GELOGE(ret, "Unlink node %s with control copy failed.", shape_input->GetName().c_str()); + return ret; + } + + // remove const without any data_output + if (shape_input->GetOutDataNodesSize() == 0) { + auto ret = IsolateAndDeleteNode(shape_input, {}); + GE_CHK_GRAPH_STATUS_RET(ret, "Fail to remove node %s", shape_input->GetName().c_str()); + GELOGI("Remove useless shape input const %s.", shape_input->GetName().c_str()); + } + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/passes/no_use_reshape_remove_pass.h b/ge/graph/passes/no_use_reshape_remove_pass.h index c142d8d2..3eb6770b 100755 --- a/ge/graph/passes/no_use_reshape_remove_pass.h +++ b/ge/graph/passes/no_use_reshape_remove_pass.h @@ -32,6 +32,9 @@ class NoUseReshapeRemovePass : public BaseNodePass { /// @author /// Status Run(ge::NodePtr &node) override; + + private: + Status TryRemoveConstShapeInput(NodePtr &reshape_node); }; } // namespace ge