| @@ -22,6 +22,7 @@ | |||||
| #define private public | #define private public | ||||
| #define protected public | #define protected public | ||||
| #include "framework/common/taskdown_common.h" | #include "framework/common/taskdown_common.h" | ||||
| #include "hybrid/executor/subgraph_context.h" | |||||
| #include "hybrid/node_executor/aicore/aicore_op_task.h" | #include "hybrid/node_executor/aicore/aicore_op_task.h" | ||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #undef private | #undef private | ||||
| @@ -80,12 +81,88 @@ TEST_F(UtestAiCoreOpTask, Init_failed) { | |||||
| DEPEND_SHAPE_RANGE); | DEPEND_SHAPE_RANGE); | ||||
| domi::TaskDef task_def; | domi::TaskDef task_def; | ||||
| task_def.set_type(RT_MODEL_TASK_KERNEL); | task_def.set_type(RT_MODEL_TASK_KERNEL); | ||||
| std::vector<uint8_t> args(10, 0); | |||||
| std::vector<uint8_t> args(100, 0); | |||||
| task_def.mutable_kernel()->set_args(args.data(), args.size()); | task_def.mutable_kernel()->set_args(args.data(), args.size()); | ||||
| task_def.mutable_kernel()->set_args_size(10); | |||||
| task_def.mutable_kernel()->set_args_size(100); | |||||
| task_def.mutable_kernel()->mutable_context()->set_kernel_type( | task_def.mutable_kernel()->mutable_context()->set_kernel_type( | ||||
| ccKernelType::TE); | ccKernelType::TE); | ||||
| uint16_t args_offset = 20; | |||||
| char *a = reinterpret_cast<char *>(&args_offset); | |||||
| task_def.mutable_kernel()->mutable_context()->set_args_offset( | |||||
| a, 2 * sizeof(uint16_t)); | |||||
| EXPECT_EQ(task1->Init(*op_desc, task_def), ge::PARAM_INVALID); | EXPECT_EQ(task1->Init(*op_desc, task_def), ge::PARAM_INVALID); | ||||
| dlog_setlevel(0, 3, 0); | dlog_setlevel(0, 3, 0); | ||||
| } | } | ||||
| TEST_F(UtestAiCoreOpTask, Init_success) { | |||||
| dlog_setlevel(0, 0, 0); | |||||
| std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask()); | |||||
| OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); | |||||
| ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, | |||||
| DEPEND_SHAPE_RANGE); | |||||
| domi::TaskDef task_def; | |||||
| task_def.set_type(RT_MODEL_TASK_KERNEL); | |||||
| std::vector<uint8_t> args(100, 0); | |||||
| task_def.mutable_kernel()->set_args(args.data(), args.size()); | |||||
| task_def.mutable_kernel()->set_args_size(100); | |||||
| task_def.mutable_kernel()->mutable_context()->set_kernel_type( | |||||
| ccKernelType::TE); | |||||
| uint16_t args_offset = 20; | |||||
| char *a = reinterpret_cast<char *>(&args_offset); | |||||
| task_def.mutable_kernel()->mutable_context()->set_args_offset( | |||||
| a, 2 * sizeof(uint16_t)); | |||||
| EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); | |||||
| dlog_setlevel(0, 3, 0); | |||||
| } | |||||
| TEST_F(UtestAiCoreOpTask, UpdateOutputsShape_success) { | |||||
| dlog_setlevel(0, 0, 0); | |||||
| std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask()); | |||||
| OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); | |||||
| ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, | |||||
| DEPEND_SHAPE_RANGE); | |||||
| domi::TaskDef task_def; | |||||
| task_def.set_type(RT_MODEL_TASK_KERNEL); | |||||
| std::vector<uint8_t> args(100, 0); | |||||
| task_def.mutable_kernel()->set_args(args.data(), args.size()); | |||||
| task_def.mutable_kernel()->set_args_size(100); | |||||
| task_def.mutable_kernel()->mutable_context()->set_kernel_type( | |||||
| ccKernelType::TE); | |||||
| uint16_t args_offset = 20; | |||||
| char *a = reinterpret_cast<char *>(&args_offset); | |||||
| task_def.mutable_kernel()->mutable_context()->set_args_offset( | |||||
| a, 2 * sizeof(uint16_t)); | |||||
| EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| NodePtr node = graph->AddNode(op_desc); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| node_item->is_dynamic = true; | |||||
| node_item->shape_inference_type = DEPEND_SHAPE_RANGE; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 2; | |||||
| graph_item.total_outputs_ = 1; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = | |||||
| std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto outputs_shape = | |||||
| reinterpret_cast<uint32_t(*)[1]>(task1->shape_buffer_->GetData()); | |||||
| outputs_shape[0][0] = 2; | |||||
| outputs_shape[0][1] = 1; | |||||
| outputs_shape[0][2] = 2; | |||||
| ASSERT_EQ(task1->UpdateOutputsShape(*node_state->GetTaskContext()), SUCCESS); | |||||
| dlog_setlevel(0, 3, 0); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||