Browse Source

!206 rm redundant MemcpyAsync node before Merge

From: @chen_yemeng
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
33438b9daf
2 changed files with 31 additions and 0 deletions
  1. +30
    -0
      ge/graph/passes/merge_pass.cc
  2. +1
    -0
      ge/graph/passes/merge_pass.h

+ 30
- 0
ge/graph/passes/merge_pass.cc View File

@@ -79,6 +79,13 @@ Status MergePass::Run(NodePtr &node) {
return FAILED; return FAILED;
} }
} }
auto in_node = in_data_nodes.at(0);
if (IsMergeInputNeedOptimized(in_node)) {
if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) {
GELOGE(FAILED, "Isolate and delete node %s failed.", in_node->GetName().c_str());
return FAILED;
}
}
return IsolateAndDeleteNode(node, merge_io_map); return IsolateAndDeleteNode(node, merge_io_map);
} }
default: { default: {
@@ -172,4 +179,27 @@ Status MergePass::CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &
GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed");
return SUCCESS; return SUCCESS;
} }

bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
if (node == nullptr) {
return false;
}
// node is not inserted by MergeInputMemcpyPass
if ((node->GetType() != MEMCPYASYNC) && (node->GetType() != MEMCPYADDRASYNC)) {
return false;
}
if (node->GetInDataNodes().size() != 1) {
return false;
}

auto in_node = node->GetInDataNodes().at(0);
if (in_node == nullptr) {
return false;
}
// in_node may be global_step var
if ((in_node->GetType() == VARIABLE) || (in_node->GetType() == VARIABLEV2)) {
return false;
}
return true;
}
} // namespace ge } // namespace ge

+ 1
- 0
ge/graph/passes/merge_pass.h View File

@@ -28,6 +28,7 @@ class MergePass : public BaseNodePass {
bool IsNeedChangeIndexToConstant(NodePtr &node) const; bool IsNeedChangeIndexToConstant(NodePtr &node) const;
Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status ChangeIndexToConstant(NodePtr &node, int &value_index);
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc);
bool IsMergeInputNeedOptimized(NodePtr &node) const;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_ #endif // GE_GRAPH_PASSES_MERGE_PASS_H_

Loading…
Cancel
Save