Browse Source

add ut002

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
fd4653e3c2
1 changed files with 39 additions and 21 deletions
  1. +39
    -21
      tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc

+ 39
- 21
tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc View File

@@ -81,7 +81,7 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string
return graph.AddNode(op_desc);
}

static void CreateSimpleCondGraph(ComputeGraph &graph) {
static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePtr &switch_f) {
/*******************************************************************************
* |
* Merge
@@ -107,19 +107,21 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) {
const auto data1 = CreateNode(graph, "data1", DATA, 1, 1);
data0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
data1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0);
AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);

const auto const0 = CreateNode(graph, "const", CONSTANT, 0, 1);
const auto const1 = CreateNode(graph, "const1", CONSTANT, 0, 1);
const0->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
const1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");
{
uint64_t const_value = 0;
uint64_t const_value = 101;
const auto op_desc = const0->GetOpDesc();
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
}
{
uint64_t const_value = 1;
uint64_t const_value = 101;
const auto op_desc = const1->GetOpDesc();
auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
@@ -128,10 +130,10 @@ static void CreateSimpleCondGraph(ComputeGraph &graph) {
const auto less1 = CreateNode(graph, "less", ENTER, 2, 1);

const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0);
const auto switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0);
const auto switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0);
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1);
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 1);
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0);
switch_f = CreateNode(graph, "switch_f", STREAMSWITCH, 2, 0);
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true.
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL);

const auto add1 = CreateNode(graph, "add", ENTER, 2, 1);
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1);
@@ -188,10 +190,6 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) {
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS);

GraphExecutionContext graph_context;
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
graph_context.model = &hybrid_model;

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
const std::vector<TensorValue> inputs{ in_tensor0 };
@@ -203,33 +201,34 @@ TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) {
auto input_desc = output0->GetOpDesc()->GetInputDescPtr(0);
const std::vector<ConstGeTensorDescPtr> input_descs{ input_desc };

GraphExecutionContext graph_context;
graph_context.model = &hybrid_model;
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context);
ASSERT_EQ(executor.ExecuteAsync(inputs, input_descs, outputs), SUCCESS);
ASSERT_EQ(executor.Synchronize(), SUCCESS);
}

TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
CreateSimpleCondGraph(*graph);
NodePtr switch_t = nullptr;
NodePtr switch_f = nullptr;
CreateSimpleCondGraph(*graph, switch_t, switch_f);

GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
GeModelPtr ge_sub_model = make_shared<GeModel>();
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
Buffer weights_buffer(1024, 0x76);
ge_sub_model->SetWeight(weights_buffer);
std::vector<uint64_t> weights_value{101, 102};
ge_sub_model->SetWeight(Buffer::CopyFrom((uint8_t *)weights_value.data(), weights_value.size() * sizeof(uint64_t)));
ge_sub_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));

HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS);

GraphExecutionContext graph_context;
graph_context.model = &hybrid_model;
graph_context.allocator = NpuMemoryAllocator::GetAllocator(0);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS);

uint64_t value_0 = 110;
uint64_t value_0 = 101; // Enter used for Less, will pass this value to switch.
TensorValue in_tensor0(&value_0, sizeof(value_0));
uint64_t value_1 = 110;
TensorValue in_tensor1(&value_1, sizeof(value_1));
@@ -242,8 +241,27 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) {
TensorUtils::SetSize(*tensor_desc, 64);
const std::vector<ConstGeTensorDescPtr> input_desc{ tensor_desc, tensor_desc };

GraphExecutionContext graph_context;
graph_context.model = &hybrid_model;
graph_context.allocator = NpuMemoryAllocator::GetAllocator(0);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS);

const auto node_it_t = hybrid_model.node_items_.find(switch_t);
const auto node_it_f = hybrid_model.node_items_.find(switch_f);
ASSERT_NE(hybrid_model.node_items_.end(), node_it_t);
ASSERT_NE(hybrid_model.node_items_.end(), node_it_f);

SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &graph_context);
ASSERT_EQ(executor.ExecuteAsync(inputs, input_desc, outputs), SUCCESS);
ASSERT_EQ(executor.Synchronize(), SUCCESS);

const auto state_it_t = executor.subgraph_context_->node_states_.find(node_it_t->second.get());
const auto state_it_f = executor.subgraph_context_->node_states_.find(node_it_f->second.get());
ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_t);
ASSERT_NE(executor.subgraph_context_->node_states_.end(), state_it_f);
ASSERT_EQ(state_it_t->second->GetSwitchIndex(), 1);
ASSERT_EQ(state_it_f->second->GetSwitchIndex(), 0);
ASSERT_EQ(graph_context.callback_manager->Destroy(), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save