From: @yangyongqiang5033 Reviewed-by: @wqtshg,@ni100die,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -486,6 +486,15 @@ Status GraphMemoryAssigner::UpdateRefOpOffsetReverse(const NodePtr &node) { | |||||
auto peer_op_desc = peer_node->GetOpDesc(); | auto peer_op_desc = peer_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(peer_op_desc); | GE_CHECK_NOTNULL(peer_op_desc); | ||||
vector<int64_t> peer_output_list = peer_op_desc->GetOutputOffset(); | vector<int64_t> peer_output_list = peer_op_desc->GetOutputOffset(); | ||||
if ((peer_out_anchor->GetIdx() >= static_cast<int>(peer_output_list.size())) | |||||
|| (out2in.first >= static_cast<int32_t>(output_list.size()))) { | |||||
GELOGW("out of range, peer_out_anchor:%d, peer_output_list size:%zu, out2in:%d, output_list size:%zu", | |||||
peer_out_anchor->GetIdx(), | |||||
peer_output_list.size(), | |||||
out2in.first, | |||||
output_list.size()); | |||||
continue; | |||||
} | |||||
peer_output_list.at(peer_out_anchor->GetIdx()) = output_list.at(out2in.first); | peer_output_list.at(peer_out_anchor->GetIdx()) = output_list.at(out2in.first); | ||||
peer_op_desc->SetOutputOffset(peer_output_list); | peer_op_desc->SetOutputOffset(peer_output_list); | ||||
GELOGD("UpdateRefOpOffsetReverse: Node[%s] output[%d] is set from node[%s] output index[%d] offset[%ld]", | GELOGD("UpdateRefOpOffsetReverse: Node[%s] output[%d] is set from node[%s] output index[%d] offset[%ld]", | ||||
@@ -339,3 +339,25 @@ TEST_F(UtestMemoryAssignerTest, graph_memory_assign_set_input_offset) { | |||||
EXPECT_EQ(assgin->GetOpDesc()->GetInputOffset()[1], 0); | EXPECT_EQ(assgin->GetOpDesc()->GetInputOffset()[1], 0); | ||||
EXPECT_EQ(memoryAssigner.CheckOffset(), GRAPH_SUCCESS); | EXPECT_EQ(memoryAssigner.CheckOffset(), GRAPH_SUCCESS); | ||||
} | } | ||||
TEST_F(UtestMemoryAssignerTest, graph_memory_assign_update_ref_op_offset_reverse) { | |||||
ge::ut::GraphBuilder builder("graph"); | |||||
auto data_input = builder.AddNode("data", "Data", 1, 1); | |||||
auto const_input = builder.AddNode("const", "Const", 1, 1); | |||||
auto add = builder.AddNode("add", "Add", 2, 1); | |||||
// add link | |||||
builder.AddDataEdge(data_input, 0, add, 0); | |||||
builder.AddDataEdge(const_input, 0, add, 1); | |||||
// set ref | |||||
uint32_t reuse_input_index = 0; | |||||
auto output_tensordesc = data_input->GetOpDesc()->MutableOutputDesc(0); | |||||
ge::TensorUtils::SetReuseInput(*output_tensordesc, true); | |||||
ge::TensorUtils::SetReuseInputIndex(*output_tensordesc, reuse_input_index); | |||||
auto output_tensordesc1 = add->GetOpDesc()->MutableOutputDesc(0); | |||||
ge::TensorUtils::SetReuseInput(*output_tensordesc1, true); | |||||
ge::TensorUtils::SetReuseInputIndex(*output_tensordesc1, reuse_input_index); | |||||
ge::ComputeGraphPtr graph = builder.GetGraph(); | |||||
GraphMemoryAssigner memoryAssigner(graph); | |||||
EXPECT_EQ(memoryAssigner.UpdateRefOpOffsetReverse(add), SUCCESS); | |||||
} |