| @@ -83,15 +83,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||||
| execution_context.profiling_level = 1; | execution_context.profiling_level = 1; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
| NodeState node_state(*node_item, &subgraph_context); | |||||
| ExecutionEngine execution_engine; | |||||
| ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
| ASSERT_TRUE(node_state->GetTaskContext() != nullptr); | |||||
| std::function<void()> callback; | std::function<void()> callback; | ||||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | ||||
| executor.InitCallback(&node_state, callback); | |||||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
| executor.InitCallback(node_state.get(), callback); | |||||
| ExecutionEngine execution_engine; | |||||
| EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
| } | } | ||||
| TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | ||||
| @@ -115,18 +114,18 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
| execution_context.model = &hybrid_model; | execution_context.model = &hybrid_model; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
| NodeState node_state(*node_item, &subgraph_context); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
| uint32_t task_id = 0; | uint32_t task_id = 0; | ||||
| uint32_t stream_id = 1; | uint32_t stream_id = 1; | ||||
| std::string task_type = "rts"; | std::string task_type = "rts"; | ||||
| uint32_t block_dim = 0; | uint32_t block_dim = 0; | ||||
| node_state.GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||||
| node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||||
| ExecutionEngine execution_engine; | |||||
| ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||||
| ASSERT_TRUE(node_state->GetTaskContext() != nullptr); | |||||
| std::function<void()> callback; | std::function<void()> callback; | ||||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | ||||
| executor.InitCallback(&node_state, callback); | |||||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
| executor.InitCallback(node_state.get(), callback); | |||||
| ExecutionEngine execution_engine; | |||||
| EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||||
| } | } | ||||
| @@ -160,9 +160,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { | |||||
| GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | SubgraphContext subgraph_context(nullptr, &execution_context); | ||||
| NodeState node_state(*node_item, &subgraph_context); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
| ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | ||||
| ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state.GetTaskContext()), SUCCESS); | |||||
| ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS); | |||||
| } | } | ||||
| TEST_F(UtestGeHybrid, index_taskdefs_failed) { | TEST_F(UtestGeHybrid, index_taskdefs_failed) { | ||||
| @@ -475,12 +475,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) { | |||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||||
| GraphItem graph_item; | |||||
| SubgraphContext subgraph_context(&graph_item, &execution_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| subgraph_context.all_inputs_.resize(2); | subgraph_context.all_inputs_.resize(2); | ||||
| subgraph_context.all_outputs_.resize(1); | subgraph_context.all_outputs_.resize(1); | ||||
| NodeState node_state(*node_item, &subgraph_context); | |||||
| auto task_context = node_state.GetTaskContext(); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
| auto task_context = node_state->GetTaskContext(); | |||||
| ASSERT_TRUE(task_context != nullptr); | ASSERT_TRUE(task_context != nullptr); | ||||
| auto desc = task_context->MutableInputDesc(2); | auto desc = task_context->MutableInputDesc(2); | ||||
| ASSERT_TRUE(desc == nullptr); | ASSERT_TRUE(desc == nullptr); | ||||
| @@ -520,12 +522,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) { | |||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| GraphExecutionContext execution_context; | GraphExecutionContext execution_context; | ||||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||||
| GraphItem graph_item; | |||||
| SubgraphContext subgraph_context(&graph_item, &execution_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| subgraph_context.all_inputs_.resize(2); | subgraph_context.all_inputs_.resize(2); | ||||
| subgraph_context.all_outputs_.resize(1); | subgraph_context.all_outputs_.resize(1); | ||||
| NodeState node_state(*node_item, &subgraph_context); | |||||
| auto task_context = node_state.GetTaskContext(); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||||
| auto task_context = node_state->GetTaskContext(); | |||||
| int32_t buffer[1]; | int32_t buffer[1]; | ||||
| aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | ||||