Browse Source

Fix NodeState for UT

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
f578e8fff4
2 changed files with 23 additions and 20 deletions
  1. +11
    -12
      tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
  2. +12
    -8
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 11
- 12
tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc View File

@@ -83,15 +83,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) {
execution_context.profiling_level = 1;
SubgraphContext subgraph_context(nullptr, &execution_context);

NodeState node_state(*node_item, &subgraph_context);

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
ASSERT_TRUE(node_state->GetTaskContext() != nullptr);

std::function<void()> callback;
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
executor.InitCallback(&node_state, callback);
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
}

TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
@@ -115,18 +114,18 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
execution_context.model = &hybrid_model;
SubgraphContext subgraph_context(nullptr, &execution_context);

NodeState node_state(*node_item, &subgraph_context);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
uint32_t task_id = 0;
uint32_t stream_id = 1;
std::string task_type = "rts";
uint32_t block_dim = 0;
node_state.GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim);
node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim);

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
ASSERT_TRUE(node_state->GetTaskContext() != nullptr);

std::function<void()> callback;
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
executor.InitCallback(&node_state, callback);
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
}

+ 12
- 8
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -160,9 +160,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
NodeState node_state(*node_item, &subgraph_context);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS);
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state.GetTaskContext()), SUCCESS);
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_failed) {
@@ -475,12 +475,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) {
node_item->output_start = 0;

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
GraphItem graph_item;
SubgraphContext subgraph_context(&graph_item, &execution_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
subgraph_context.all_inputs_.resize(2);
subgraph_context.all_outputs_.resize(1);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = node_state.GetTaskContext();
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
auto task_context = node_state->GetTaskContext();
ASSERT_TRUE(task_context != nullptr);
auto desc = task_context->MutableInputDesc(2);
ASSERT_TRUE(desc == nullptr);
@@ -520,12 +522,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) {
node_item->output_start = 0;

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
GraphItem graph_item;
SubgraphContext subgraph_context(&graph_item, &execution_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
subgraph_context.all_inputs_.resize(2);
subgraph_context.all_outputs_.resize(1);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = node_state.GetTaskContext();
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
auto task_context = node_state->GetTaskContext();

int32_t buffer[1];
aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer));


Loading…
Cancel
Save