|
|
@@ -81,7 +81,7 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string |
|
|
|
return graph.AddNode(op_desc); |
|
|
|
} |
|
|
|
|
|
|
|
static void CreateSimpleCondGraph(ComputeGraph &graph) { |
|
|
|
static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePtr &switch_f) { |
|
|
|
/******************************************************************************* |
|
|
|
* | |
|
|
|
* Merge |
|
|
@@ -107,19 +107,21 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) { |
|
|
|
const auto data1 = CreateNode(graph, "data1", DATA, 1, 1); |
|
|
|
data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); |
|
|
|
data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); |
|
|
|
AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); |
|
|
|
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); |
|
|
|
|
|
|
|
const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1); |
|
|
|
const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1); |
|
|
|
const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); |
|
|
|
const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); |
|
|
|
{ |
|
|
|
uint64_t const_value = 0; |
|
|
|
uint64_t const_value = 101; |
|
|
|
const auto op_desc = const0->GetOpDesc(); |
|
|
|
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); |
|
|
|
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); |
|
|
|
} |
|
|
|
{ |
|
|
|
uint64_t const_value = 1; |
|
|
|
uint64_t const_value = 101; |
|
|
|
const auto op_desc = const1->GetOpDesc(); |
|
|
|
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); |
|
|
|
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); |
|
|
@@ -128,10 +130,10 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) { |
|
|
|
const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); |
|
|
|
|
|
|
|
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); |
|
|
|
const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); |
|
|
|
const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); |
|
|
|
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); |
|
|
|
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1); |
|
|
|
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); |
|
|
|
switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0); |
|
|
|
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. |
|
|
|
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); |
|
|
|
|
|
|
|
const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); |
|
|
|
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); |
|
|
@@ -188,10 +190,6 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { |
|
|
|
HybridModelBuilder hybrid_model_builder(hybrid_model); |
|
|
|
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); |
|
|
|
|
|
|
|
GraphExecutionContext graph_context; |
|
|
|
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); |
|
|
|
graph_context.model = &hybrid_model; |
|
|
|
|
|
|
|
uint64_t value_0 = 110; |
|
|
|
TensorValue in_tensor0(&value_0, sizeof(value_0)); |
|
|
|
const std::vector<TensorValue> inputs{ in_tensor0 }; |
|
|
@@ -203,33 +201,34 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { |
|
|
|
auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0); |
|
|
|
const std::vector<ConstGeTensorDescPtr> input_descs{ input_desc }; |
|
|
|
|
|
|
|
GraphExecutionContext graph_context; |
|
|
|
graph_context.model = &hybrid_model; |
|
|
|
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); |
|
|
|
|
|
|
|
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); |
|
|
|
ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS); |
|
|
|
ASSERT_EQ(executor.Synchronize(), SUCCESS); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { |
|
|
|
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); |
|
|
|
CreateSimpleCondGraph(*graph); |
|
|
|
NodePtr switch_t = nullptr; |
|
|
|
NodePtr switch_f = nullptr; |
|
|
|
CreateSimpleCondGraph(*graph, switch_t, switch_f); |
|
|
|
|
|
|
|
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); |
|
|
|
ge_root_model->SetModelName("test_name"); |
|
|
|
GeModelPtr ge_sub_model = make_shared<GeModel>(); |
|
|
|
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); |
|
|
|
Buffer weights_buffer(1024, 0x76); |
|
|
|
ge_sub_model->SetWeight(weights_buffer); |
|
|
|
std::vector<uint64_t> weights_value{101, 102}; |
|
|
|
ge_sub_model->SetWeight(Buffer::CopyFrom((uint8_t *)weights_value.data(), weights_value.size() * sizeof(uint64_t))); |
|
|
|
ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); |
|
|
|
|
|
|
|
HybridModel hybrid_model(ge_root_model); |
|
|
|
HybridModelBuilder hybrid_model_builder(hybrid_model); |
|
|
|
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); |
|
|
|
|
|
|
|
GraphExecutionContext graph_context; |
|
|
|
graph_context.model = &hybrid_model; |
|
|
|
graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); |
|
|
|
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); |
|
|
|
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); |
|
|
|
|
|
|
|
uint64_t value_0 = 110; |
|
|
|
uint64_t value_0 = 101; // Enter used for Less, will pass this value to switch. |
|
|
|
TensorValue in_tensor0(&value_0, sizeof(value_0)); |
|
|
|
uint64_t value_1 = 110; |
|
|
|
TensorValue in_tensor1(&value_1, sizeof(value_1)); |
|
|
@@ -242,8 +241,27 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { |
|
|
|
TensorUtils::SetSize(*tensor_desc, 64); |
|
|
|
const std::vector<ConstGeTensorDescPtr> input_desc{ tensor_desc, tensor_desc }; |
|
|
|
|
|
|
|
GraphExecutionContext graph_context; |
|
|
|
graph_context.model = &hybrid_model; |
|
|
|
graph_context.allocator = NpuMemoryAllocator::GetAllocator(0); |
|
|
|
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); |
|
|
|
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); |
|
|
|
|
|
|
|
const auto node_it_t = hybrid_model.node_items_.find(switch_t); |
|
|
|
const auto node_it_f = hybrid_model.node_items_.find(switch_f); |
|
|
|
ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); |
|
|
|
ASSERT_NE(hybrid_model.node_items_.end(), node_it_f); |
|
|
|
|
|
|
|
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context); |
|
|
|
ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS); |
|
|
|
ASSERT_EQ(executor.Synchronize(), SUCCESS); |
|
|
|
|
|
|
|
const auto state_it_t = executor.subgraph_context_->node_states_.find(node_it_t->second.get()); |
|
|
|
const auto state_it_f = executor.subgraph_context_->node_states_.find(node_it_f->second.get()); |
|
|
|
ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_t); |
|
|
|
ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_f); |
|
|
|
ASSERT_EQ(state_it_t->second->GetSwitchIndex(), 1); |
|
|
|
ASSERT_EQ(state_it_f->second->GetSwitchIndex(), 0); |
|
|
|
ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS); |
|
|
|
} |
|
|
|
} // namespace ge |