From 634cc2ef06becea63fc7ee2d03834d287d24ce18 Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Thu, 5 Nov 2020 18:51:59 +0800 Subject: [PATCH] rm redundant Memcpy before Merge --- ge/graph/passes/merge_pass.cc | 28 ++++++++++++++++++++++++---- ge/graph/passes/merge_pass.h | 1 + 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index 8938b47e..d2340037 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -80,10 +80,7 @@ Status MergePass::Run(NodePtr &node) { } } auto in_node = in_data_nodes.at(0); - bool memcpy_optimize_flag = (in_node != nullptr) && - ((in_node->GetType() == MEMCPYASYNC) || (in_node->GetType() == MEMCPYADDRASYNC)) && - (in_node->GetInDataNodes().size() == 1); - if (memcpy_optimize_flag) { + if (IsMergeInputNeedOptimized(in_node)) { if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) { GELOGE(FAILED, "Isolate and delete node %s failed.", in_node->GetName().c_str()); return FAILED; @@ -182,4 +179,27 @@ Status MergePass::CreateConstByValue(NodePtr &node, int value_index, OpDescPtr & GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); return SUCCESS; } + +bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { + if (node == nullptr) { + return false; + } + // node is not inserted by MergeInputMemcpyPass + if ((node->GetType() != MEMCPYASYNC) && (node->GetType() != MEMCPYADDRASYNC)) { + return false; + } + if (node->GetInDataNodes().size() != 1) { + return false; + } + + auto in_node = node->GetInDataNodes().at(0); + if (in_node == nullptr) { + return false; + } + // in_node may be global_step var + if ((in_node->GetType() == VARIABLE) || (in_node->GetType() == VARIABLEV2)) { + return false; + } + return true; +} } // namespace ge diff --git a/ge/graph/passes/merge_pass.h b/ge/graph/passes/merge_pass.h index 53582ff6..2cdb5022 100755 --- a/ge/graph/passes/merge_pass.h +++ b/ge/graph/passes/merge_pass.h @@ -28,6 +28,7 @@ class MergePass : public BaseNodePass { bool IsNeedChangeIndexToConstant(NodePtr &node) const; Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); + bool IsMergeInputNeedOptimized(NodePtr &node) const; }; } // namespace ge #endif // GE_GRAPH_PASSES_MERGE_PASS_H_