From fd4653e3c25e111ba2129a5b5caa3d065630e94e Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sun, 25 Apr 2021 21:19:03 +0800 Subject: [PATCH] add ut002 --- .../executor/subgraph_executor_unittest.cc | 60 ++++++++++++------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index 5e9aa0e8..fbda3776 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -81,7 +81,7 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string return graph.AddNode(op_desc); } -static void CreateSimpleCondGraph(ComputeGraph &graph) { +static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePtr &switch_f) { /******************************************************************************* * | * Merge @@ -107,19 +107,21 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) { const auto data1 = CreateNode(graph, "data1", DATA, 1, 1); data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); + AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); + AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1); const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1); const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); { - uint64_t const_value = 0; + uint64_t const_value = 101; const auto op_desc = const0->GetOpDesc(); auto weight = make_shared(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); } { - uint64_t const_value = 1; + uint64_t const_value = 101; const auto op_desc = const1->GetOpDesc(); auto weight = make_shared(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); @@ -128,10 +130,10 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) { const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); - const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); - const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); - AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); - AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); + switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); + switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); + AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. + AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); @@ -188,10 +190,6 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { HybridModelBuilder hybrid_model_builder(hybrid_model); ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); - GraphExecutionContext graph_context; - graph_context.callback_manager = std::unique_ptr(new CallbackManager()); - graph_context.model = &hybrid_model; - uint64_t value_0 = 110; TensorValue in_tensor0(&value_0, sizeof(value_0)); const std::vector inputs{ in_tensor0 }; @@ -203,33 +201,34 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0); const std::vector input_descs{ input_desc }; + GraphExecutionContext graph_context; + graph_context.model = &hybrid_model; + graph_context.callback_manager = std::unique_ptr(new CallbackManager()); + SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS); + ASSERT_EQ(executor.Synchronize(), SUCCESS); } TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { ComputeGraphPtr graph = std::make_shared("test"); - CreateSimpleCondGraph(*graph); + NodePtr switch_t = nullptr; + NodePtr switch_f = nullptr; + CreateSimpleCondGraph(*graph, switch_t, switch_f); GeRootModelPtr ge_root_model = make_shared(graph); ge_root_model->SetModelName("test_name"); GeModelPtr ge_sub_model = make_shared(); ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); - Buffer weights_buffer(1024, 0x76); - ge_sub_model->SetWeight(weights_buffer); + std::vector weights_value{101, 102}; + ge_sub_model->SetWeight(Buffer::CopyFrom((uint8_t *)weights_value.data(), weights_value.size() * sizeof(uint64_t))); ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); HybridModel hybrid_model(ge_root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); - GraphExecutionContext graph_context; - graph_context.model = &hybrid_model; - graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); - graph_context.callback_manager = std::unique_ptr(new CallbackManager()); - ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); - - uint64_t value_0 = 110; + uint64_t value_0 = 101; // Enter used for Less, will pass this value to switch. TensorValue in_tensor0(&value_0, sizeof(value_0)); uint64_t value_1 = 110; TensorValue in_tensor1(&value_1, sizeof(value_1)); @@ -242,8 +241,27 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { TensorUtils::SetSize(*tensor_desc, 64); const std::vector input_desc{ tensor_desc, tensor_desc }; + GraphExecutionContext graph_context; + graph_context.model = &hybrid_model; + graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); + graph_context.callback_manager = std::unique_ptr(new CallbackManager()); + ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); + + const auto node_it_t = hybrid_model.node_items_.find(switch_t); + const auto node_it_f = hybrid_model.node_items_.find(switch_f); + ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); + ASSERT_NE(hybrid_model.node_items_.end(), node_it_f); + SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS); + ASSERT_EQ(executor.Synchronize(), SUCCESS); + + const auto state_it_t = executor.subgraph_context_->node_states_.find(node_it_t->second.get()); + const auto state_it_f = executor.subgraph_context_->node_states_.find(node_it_f->second.get()); + ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_t); + ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_f); + ASSERT_EQ(state_it_t->second->GetSwitchIndex(), 1); + ASSERT_EQ(state_it_f->second->GetSwitchIndex(), 0); ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS); } } // namespace ge \ No newline at end of file