From afb7bdd6b8d1f51b3dcee02eb1644a4cf4a191d7 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Wed, 2 Dec 2020 11:45:55 +0800 Subject: [PATCH 1/3] modified: ge/graph/optimize/mem_rw_conflict_optimize.cc --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index e0b4b52c..2fabc035 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -335,7 +335,8 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str()); - if (pre_output_rw_type == OutputRWType::kWriteable) { + auto parent_node = sub_graph->GetParentNode(); + if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { // insert identity auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); GE_CHECK_NOTNULL(identity_node); @@ -346,8 +347,9 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { } GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), pre_node->GetName().c_str(), node->GetName().c_str()); + pre_output_rw_type = OutputRWType::kSoftRead; } - output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), OutputRWType::kSoftRead)); + output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), pre_output_rw_type)); } NodeInputOutputRWType output_rw_type{{}, output_rw_type_map}; node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type)); From 5f9e80b73e4f29e2fac5914896336d23407c597c Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Wed, 2 Dec 2020 13:45:25 +0800 Subject: [PATCH 2/3] modified: ge/graph/optimize/mem_rw_conflict_optimize.cc --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 2fabc035..2b607fdd 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -234,7 +234,7 @@ InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) { return InputRWType::kInvalidRWType; } if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER - || op_desc->GetType() == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE) { + || op_desc->GetType() == HCOMREDUCESCATTER) { return InputRWType::kScopeWriteable; } // check if it is ref input From c599e8a9d0d0af9ab9243a3562f233198b3e5f07 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Wed, 2 Dec 2020 14:15:11 +0800 Subject: [PATCH 3/3] modified: ge/graph/optimize/mem_rw_conflict_optimize.cc --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 2b607fdd..2fabc035 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -234,7 +234,7 @@ InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) { return InputRWType::kInvalidRWType; } if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER - || op_desc->GetType() == HCOMREDUCESCATTER) { + || op_desc->GetType() == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE) { return InputRWType::kScopeWriteable; } // check if it is ref input