diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc
index 8becd90e..b098a5f5 100755
--- a/ge/graph/build/memory/graph_mem_assigner.cc
+++ b/ge/graph/build/memory/graph_mem_assigner.cc
@@ -69,6 +69,10 @@ int64_t GetSymbolOutputOffset(const std::map &anchor_t
   }
   return ge::kInvalidOffset;
 }
+
+bool isVariableMemoryNode(const ge::NodePtr &node) {
+  return (node->GetType() == ge::VARIABLE) || (node->GetType() == ge::CONSTANTOP);
+}
 }  // namespace
 namespace ge {
 Status VariableMemoryAssigner::Assign() {
@@ -447,22 +451,31 @@ bool IsContinuousInputConflict(const ge::NodePtr &node, const OpDescPtr &peer_op
 /// op1 -> node -> op2
 /// return true when node is ref from input, and op1 or op2 is reuse input from output
 bool GraphMemoryAssigner::IsRefFromInputOpCascade(const NodePtr &node) {
-  bool ref_from_input = false;
+  std::unordered_set<int32_t> ref_input_index;
   int32_t reuse_in_index = -1;
   for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
-    ref_from_input = GraphUtils::IsRefFromInput(out_anchor, reuse_in_index);
-    if (ref_from_input) {
+    bool reuse_input = GraphUtils::IsRefFromInput(out_anchor, reuse_in_index);
+    if (reuse_input) {
       GELOGD("IsRefFromInputOpCascade: cur node:%s:%d is ref", node->GetName().c_str(), reuse_in_index);
-      break;
+      ref_input_index.insert(reuse_in_index);
     }
   }
+  bool ref_from_input = !ref_input_index.empty();
+  if (!ref_from_input) {
+    return false;
+  }
 
   for (const auto &in_anchor : node->GetAllInDataAnchors()) {
     const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor();
     GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
+    auto in_node = peer_out_anchor->GetOwnerNode();
+    if (isVariableMemoryNode(in_node) && (ref_input_index.count(in_anchor->GetIdx()) > 0)) {
+      GELOGD("Reuse variable memory, input node:%s, type:%s.", in_node->GetName().c_str(), in_node->GetType().c_str());
+      return false;
+    }
     if (ref_from_input && GraphUtils::IsRefFromInput(peer_out_anchor, reuse_in_index)) {
       GELOGD("IsRefFromInputOpCascade: in node[%s] is ref, reuse index is:%d",
-             peer_out_anchor->GetOwnerNode()->GetName().c_str(), reuse_in_index);
+             in_node->GetName().c_str(), reuse_in_index);
       return true;
     }
   }
@@ -500,6 +513,11 @@ Status GraphMemoryAssigner::UpdateRefOpOffsetReverse(const NodePtr &node) {
     GE_CHECK_NOTNULL(peer_out_anchor);
     auto peer_node = peer_out_anchor->GetOwnerNode();
     GE_CHECK_NOTNULL(peer_node);
+    if (isVariableMemoryNode(peer_node)) {
+      GELOGW("Peer node to update is %s, skip it. Node name:%s.",
+             peer_node->GetType().c_str(), peer_node->GetName().c_str());
+      continue;
+    }
     auto peer_op_desc = peer_node->GetOpDesc();
     GE_CHECK_NOTNULL(peer_op_desc);
     vector<int64_t> peer_output_list = peer_op_desc->GetOutputOffset();
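The change above makes IsRefFromInputOpCascade remember every output index that reuses an input, and then refuse to treat the node as a cascade when the reused input is produced by a VARIABLE or CONSTANTOP node, so variable memory is not rewritten through the reverse-offset path. Below is a minimal, self-contained sketch of that control flow; MiniNode and its fields are hypothetical stand-ins for the real GraphEngine Node/GraphUtils types, so this illustrates the idea rather than the actual implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for GraphEngine's node/anchor model, kept just rich enough
// to mirror the patched control flow of IsRefFromInputOpCascade.
struct MiniNode {
  std::string name;
  std::string type;                      // e.g. "Variable", "HcomBroadcast", "Assign"
  std::vector<int32_t> ref_reuse_index;  // per output: index of the reused input, or -1 if not a ref output
  std::vector<const MiniNode *> inputs;  // peer node feeding each input anchor
};

bool IsVariableMemory(const MiniNode &node) {
  return node.type == "Variable" || node.type == "Constant";
}

// Collect every output-to-input reuse first; then refuse the cascade when the reused
// input is variable memory, otherwise report a cascade if any peer is itself a ref op.
bool IsRefFromInputOpCascade(const MiniNode &node) {
  std::unordered_set<int32_t> ref_input_index;
  for (int32_t idx : node.ref_reuse_index) {
    if (idx >= 0) {
      ref_input_index.insert(idx);
    }
  }
  if (ref_input_index.empty()) {
    return false;  // node itself does not reuse any input
  }
  for (std::size_t in_idx = 0; in_idx < node.inputs.size(); ++in_idx) {
    const MiniNode *peer = node.inputs[in_idx];
    if (peer == nullptr) {
      continue;
    }
    // Reusing variable memory: never treat this as a cascade, so offsets are not rewritten.
    if (IsVariableMemory(*peer) && ref_input_index.count(static_cast<int32_t>(in_idx)) > 0) {
      return false;
    }
    // op1 -> node -> op2 case: the upstream peer also reuses one of its inputs.
    for (int32_t idx : peer->ref_reuse_index) {
      if (idx >= 0) {
        return true;
      }
    }
  }
  return false;
}

int main() {
  MiniNode var{"var", "Variable", {-1}, {}};
  MiniNode broadcast{"broadcast", "HcomBroadcast", {0}, {&var}};  // output 0 reuses input 0 (the variable)
  MiniNode assign{"assign", "Assign", {0}, {&var, &broadcast}};   // output 0 reuses input 0 (the variable)

  // Mirrors the expectations of the new unit test: both checks come back false
  // because the reused input is variable memory.
  bool any_cascade = IsRefFromInputOpCascade(broadcast) || IsRefFromInputOpCascade(assign);
  return any_cascade ? 1 : 0;
}
```

Built as a plain C++ program, the graph in main mirrors the var -> broadcast -> assign topology used by the new unit test further down in this diff.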
diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc
index 6959c6b3..5b5f24a2 100755
--- a/ge/single_op/single_op_model.cc
+++ b/ge/single_op/single_op_model.cc
@@ -613,7 +613,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &
   single_op.num_inputs_ = data_ops_.size();
   single_op.num_outputs_ = netoutput_op_->GetAllInputsSize();
   GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource));
-  model_params_.memory_size = UINT_MAX;
+  model_params_.memory_size = UINT64_MAX;
   model_params_.graph_is_dynamic = true;
 
   auto ge_model = model_helper_.GetGeModel();
diff --git a/tests/ut/ge/graph/build/mem_assigner_unittest.cc b/tests/ut/ge/graph/build/mem_assigner_unittest.cc
index c9b0b579..785af2ef 100644
--- a/tests/ut/ge/graph/build/mem_assigner_unittest.cc
+++ b/tests/ut/ge/graph/build/mem_assigner_unittest.cc
@@ -525,6 +525,34 @@ TEST_F(UtestMemoryAssignerTest, graph_memory_assign_update_ref_op_offset_reverse
   EXPECT_EQ(memoryAssigner.UpdateRefOpOffsetReverse(add), SUCCESS);
 }
 
+TEST_F(UtestMemoryAssignerTest, graph_memory_assign_var_input_ref_cascade_false) {
+  ge::ut::GraphBuilder builder("graph");
+  auto var = builder.AddNode("var", VARIABLE, 1, 1);
+  auto broadcast = builder.AddNode("broadcast", HCOMBROADCAST, 1, 1);
+  auto assign = builder.AddNode("assign", "Assign", 2, 1);
+  // add link
+  builder.AddDataEdge(var, 0, assign, 0);
+  builder.AddDataEdge(var, 0, broadcast, 0);
+  builder.AddDataEdge(broadcast, 0, assign, 1);
+
+  int reuse_input_index = 0;
+  auto broadcast_desc = broadcast->GetOpDesc()->MutableOutputDesc(0);
+  ge::TensorUtils::SetReuseInput(*broadcast_desc, true);
+  ge::TensorUtils::SetReuseInputIndex(*broadcast_desc, reuse_input_index);
+
+  ge::ComputeGraphPtr graph = builder.GetGraph();
+
+  GraphMemoryAssigner memory_assigner(graph);
+  bool ref_cascade = memory_assigner.IsRefFromInputOpCascade(broadcast);
+  EXPECT_EQ(ref_cascade, false);
+  ref_cascade = memory_assigner.IsRefFromInputOpCascade(assign);
+  EXPECT_EQ(ref_cascade, false);
+  auto ret = memory_assigner.UpdateRefOpOffsetReverse(broadcast);
+  EXPECT_EQ(ret, SUCCESS);
+  ret = memory_assigner.UpdateRefOpOffsetReverse(assign);
+  EXPECT_EQ(ret, SUCCESS);
+}
+
 TEST_F(UtestMemoryAssignerTest, graph_memory_assign_atomic_output_and_workspace) {
   ge::ut::GraphBuilder builder("graph");
   auto data_input = builder.AddNode("data", "Data", 1, 1);
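On the single_op_model.cc change: UINT_MAX is the 32-bit maximum, so storing it in a 64-bit memory_size field caps the value at roughly 4 GiB, while UINT64_MAX is the matching 64-bit sentinel. The tiny standalone comparison below only illustrates the numeric difference; the assumption that memory_size is a uint64_t used as an "effectively unlimited" bound, and the bound check itself, are illustrative stand-ins, not GE's actual validation code.

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumption: memory_size is a 64-bit field acting as an upper bound for dynamic
  // single ops; the check below is purely illustrative.
  const uint64_t requested = 5ULL * 1024 * 1024 * 1024;  // a 5 GiB request

  const uint64_t old_sentinel = UINT_MAX;    // 32-bit max: 4294967295 (~4 GiB)
  const uint64_t new_sentinel = UINT64_MAX;  // 64-bit max: 18446744073709551615

  std::printf("UINT_MAX   = %llu\n", static_cast<unsigned long long>(old_sentinel));
  std::printf("UINT64_MAX = %llu\n", static_cast<unsigned long long>(new_sentinel));
  std::printf("5 GiB fits under UINT_MAX:   %s\n", requested <= old_sentinel ? "yes" : "no");
  std::printf("5 GiB fits under UINT64_MAX: %s\n", requested <= new_sentinel ? "yes" : "no");
  return 0;
}
```

Compiled with any C++11 toolchain, it shows the 5 GiB request passing only against the 64-bit sentinel.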