Browse Source

!483 Bugfix: fix mem_rw conflict unnecessary identity in PartitionCall subgraph

From: @hugo1
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.1.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
111feee102
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      ge/graph/optimize/mem_rw_conflict_optimize.cc

+ 4
- 2
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -335,7 +335,8 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(), GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(),
pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str()); pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str());
if (pre_output_rw_type == OutputRWType::kWriteable) {
auto parent_node = sub_graph->GetParentNode();
if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) {
// insert identity // insert identity
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node); GE_CHECK_NOTNULL(identity_node);
@@ -346,8 +347,9 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
} }
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str()); pre_node->GetName().c_str(), node->GetName().c_str());
pre_output_rw_type = OutputRWType::kSoftRead;
} }
output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), OutputRWType::kSoftRead));
output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), pre_output_rw_type));
} }
NodeInputOutputRWType output_rw_type{{}, output_rw_type_map}; NodeInputOutputRWType output_rw_type{{}, output_rw_type_map};
node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type)); node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type));


Loading…
Cancel
Save