| @@ -81,7 +81,7 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||
| return graph.AddNode(op_desc); | |||
| } | |||
| static void CreateSimpleCondGraph(ComputeGraph &graph) { | |||
| static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePtr &switch_f) { | |||
| /******************************************************************************* | |||
| * | | |||
| * Merge | |||
| @@ -107,19 +107,21 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) { | |||
| const auto data1 = CreateNode(graph, "data1", DATA, 1, 1); | |||
| data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||
| data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||
| AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||
| const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1); | |||
| const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1); | |||
| const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||
| const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | |||
| { | |||
| uint64_t const_value = 0; | |||
| uint64_t const_value = 101; | |||
| const auto op_desc = const0->GetOpDesc(); | |||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||
| AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||
| } | |||
| { | |||
| uint64_t const_value = 1; | |||
| uint64_t const_value = 101; | |||
| const auto op_desc = const1->GetOpDesc(); | |||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||
| AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | |||
| @@ -128,10 +130,10 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) { | |||
| const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); | |||
| const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | |||
| const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||
| const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); | |||
| AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); | |||
| AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); | |||
| switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | |||
| switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); | |||
| AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | |||
| AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | |||
| const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); | |||
| const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); | |||
| @@ -188,10 +190,6 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | |||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
| ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||
| GraphExecutionContext graph_context; | |||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||
| graph_context.model = &hybrid_model; | |||
| uint64_t value_0 = 110; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| const std::vector<TensorValue> inputs{ in_tensor0 }; | |||
| @@ -203,33 +201,34 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | |||
| auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0); | |||
| const std::vector<ConstGeTensorDescPtr> input_descs{ input_desc }; | |||
| GraphExecutionContext graph_context; | |||
| graph_context.model = &hybrid_model; | |||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); | |||
| ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS); | |||
| ASSERT_EQ(executor.Synchronize(), SUCCESS); | |||
| } | |||
| TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| CreateSimpleCondGraph(*graph); | |||
| NodePtr switch_t = nullptr; | |||
| NodePtr switch_f = nullptr; | |||
| CreateSimpleCondGraph(*graph, switch_t, switch_f); | |||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||
| ge_root_model->SetModelName("test_name"); | |||
| GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||
| Buffer weights_buffer(1024, 0x76); | |||
| ge_sub_model->SetWeight(weights_buffer); | |||
| std::vector<uint64_t> weights_value{101, 102}; | |||
| ge_sub_model->SetWeight(Buffer::CopyFrom((uint8_t *)weights_value.data(), weights_value.size() * sizeof(uint64_t))); | |||
| ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||
| HybridModel hybrid_model(ge_root_model); | |||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
| ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||
| GraphExecutionContext graph_context; | |||
| graph_context.model = &hybrid_model; | |||
| graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); | |||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||
| ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); | |||
| uint64_t value_0 = 110; | |||
| uint64_t value_0 = 101; // Enter used for Less, will pass this value to switch. | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| uint64_t value_1 = 110; | |||
| TensorValue in_tensor1(&value_1, sizeof(value_1)); | |||
| @@ -242,8 +241,27 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { | |||
| TensorUtils::SetSize(*tensor_desc, 64); | |||
| const std::vector<ConstGeTensorDescPtr> input_desc{ tensor_desc, tensor_desc }; | |||
| GraphExecutionContext graph_context; | |||
| graph_context.model = &hybrid_model; | |||
| graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); | |||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||
| ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); | |||
| const auto node_it_t = hybrid_model.node_items_.find(switch_t); | |||
| const auto node_it_f = hybrid_model.node_items_.find(switch_f); | |||
| ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); | |||
| ASSERT_NE(hybrid_model.node_items_.end(), node_it_f); | |||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); | |||
| ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS); | |||
| ASSERT_EQ(executor.Synchronize(), SUCCESS); | |||
| const auto state_it_t = executor.subgraph_context_->node_states_.find(node_it_t->second.get()); | |||
| const auto state_it_f = executor.subgraph_context_->node_states_.find(node_it_f->second.get()); | |||
| ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_t); | |||
| ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_f); | |||
| ASSERT_EQ(state_it_t->second->GetSwitchIndex(), 1); | |||
| ASSERT_EQ(state_it_f->second->GetSwitchIndex(), 0); | |||
| ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS); | |||
| } | |||
| } // namespace ge | |||