@@ -83,15 +83,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
execution_context.profiling_level = 1; | execution_context.profiling_level = 1; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
NodeState node_state(*node_item, &subgraph_context); | |||||
ExecutionEngine execution_engine; | |||||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
ASSERT_TRUE(node_state->GetTaskContext() != nullptr); | |||||
std::function<void()> callback; | std::function<void()> callback; | ||||
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | ||||
executor.InitCallback(&node_state, callback); | |||||
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
executor.InitCallback(node_state.get(), callback); | |||||
ExecutionEngine execution_engine; | |||||
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
} | } | ||||
TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | ||||
@@ -115,18 +114,18 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
execution_context.model = &hybrid_model; | execution_context.model = &hybrid_model; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
NodeState node_state(*node_item, &subgraph_context); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
uint32_t task_id = 0; | uint32_t task_id = 0; | ||||
uint32_t stream_id = 1; | uint32_t stream_id = 1; | ||||
std::string task_type = "rts"; | std::string task_type = "rts"; | ||||
uint32_t block_dim = 0; | uint32_t block_dim = 0; | ||||
node_state.GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||||
node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||||
ExecutionEngine execution_engine; | |||||
ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||||
ASSERT_TRUE(node_state->GetTaskContext() != nullptr); | |||||
std::function<void()> callback; | std::function<void()> callback; | ||||
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | ||||
executor.InitCallback(&node_state, callback); | |||||
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
executor.InitCallback(node_state.get(), callback); | |||||
ExecutionEngine execution_engine; | |||||
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
} | } |
@@ -160,9 +160,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { | |||||
GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
NodeState node_state(*node_item, &subgraph_context); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | ||||
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state.GetTaskContext()), SUCCESS); | |||||
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS); | |||||
} | } | ||||
TEST_F(UtestGeHybrid, index_taskdefs_failed) { | TEST_F(UtestGeHybrid, index_taskdefs_failed) { | ||||
@@ -475,12 +475,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) { | |||||
node_item->output_start = 0; | node_item->output_start = 0; | ||||
GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | |||||
GraphItem graph_item; | |||||
SubgraphContext subgraph_context(&graph_item, &execution_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
subgraph_context.all_inputs_.resize(2); | subgraph_context.all_inputs_.resize(2); | ||||
subgraph_context.all_outputs_.resize(1); | subgraph_context.all_outputs_.resize(1); | ||||
NodeState node_state(*node_item, &subgraph_context); | |||||
auto task_context = node_state.GetTaskContext(); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
auto task_context = node_state->GetTaskContext(); | |||||
ASSERT_TRUE(task_context != nullptr); | ASSERT_TRUE(task_context != nullptr); | ||||
auto desc = task_context->MutableInputDesc(2); | auto desc = task_context->MutableInputDesc(2); | ||||
ASSERT_TRUE(desc == nullptr); | ASSERT_TRUE(desc == nullptr); | ||||
@@ -520,12 +522,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) { | |||||
node_item->output_start = 0; | node_item->output_start = 0; | ||||
GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
SubgraphContext subgraph_context(nullptr, &execution_context); | |||||
GraphItem graph_item; | |||||
SubgraphContext subgraph_context(&graph_item, &execution_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
subgraph_context.all_inputs_.resize(2); | subgraph_context.all_inputs_.resize(2); | ||||
subgraph_context.all_outputs_.resize(1); | subgraph_context.all_outputs_.resize(1); | ||||
NodeState node_state(*node_item, &subgraph_context); | |||||
auto task_context = node_state.GetTaskContext(); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
auto task_context = node_state->GetTaskContext(); | |||||
int32_t buffer[1]; | int32_t buffer[1]; | ||||
aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | ||||