| @@ -22,6 +22,7 @@ | |||
| #define private public | |||
| #define protected public | |||
| #include "framework/common/taskdown_common.h" | |||
| #include "hybrid/executor/subgraph_context.h" | |||
| #include "hybrid/node_executor/aicore/aicore_op_task.h" | |||
| #include "init/gelib.h" | |||
| #undef private | |||
| @@ -80,12 +81,88 @@ TEST_F(UtestAiCoreOpTask, Init_failed) { | |||
| DEPEND_SHAPE_RANGE); | |||
| domi::TaskDef task_def; | |||
| task_def.set_type(RT_MODEL_TASK_KERNEL); | |||
| std::vector<uint8_t> args(10, 0); | |||
| std::vector<uint8_t> args(100, 0); | |||
| task_def.mutable_kernel()->set_args(args.data(), args.size()); | |||
| task_def.mutable_kernel()->set_args_size(10); | |||
| task_def.mutable_kernel()->set_args_size(100); | |||
| task_def.mutable_kernel()->mutable_context()->set_kernel_type( | |||
| ccKernelType::TE); | |||
| uint16_t args_offset = 20; | |||
| char *a = reinterpret_cast<char *>(&args_offset); | |||
| task_def.mutable_kernel()->mutable_context()->set_args_offset( | |||
| a, 2 * sizeof(uint16_t)); | |||
| EXPECT_EQ(task1->Init(*op_desc, task_def), ge::PARAM_INVALID); | |||
| dlog_setlevel(0, 3, 0); | |||
| } | |||
| TEST_F(UtestAiCoreOpTask, Init_success) { | |||
| dlog_setlevel(0, 0, 0); | |||
| std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask()); | |||
| OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); | |||
| ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, | |||
| DEPEND_SHAPE_RANGE); | |||
| domi::TaskDef task_def; | |||
| task_def.set_type(RT_MODEL_TASK_KERNEL); | |||
| std::vector<uint8_t> args(100, 0); | |||
| task_def.mutable_kernel()->set_args(args.data(), args.size()); | |||
| task_def.mutable_kernel()->set_args_size(100); | |||
| task_def.mutable_kernel()->mutable_context()->set_kernel_type( | |||
| ccKernelType::TE); | |||
| uint16_t args_offset = 20; | |||
| char *a = reinterpret_cast<char *>(&args_offset); | |||
| task_def.mutable_kernel()->mutable_context()->set_args_offset( | |||
| a, 2 * sizeof(uint16_t)); | |||
| EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); | |||
| dlog_setlevel(0, 3, 0); | |||
| } | |||
| TEST_F(UtestAiCoreOpTask, UpdateOutputsShape_success) { | |||
| dlog_setlevel(0, 0, 0); | |||
| std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask()); | |||
| OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1); | |||
| ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, | |||
| DEPEND_SHAPE_RANGE); | |||
| domi::TaskDef task_def; | |||
| task_def.set_type(RT_MODEL_TASK_KERNEL); | |||
| std::vector<uint8_t> args(100, 0); | |||
| task_def.mutable_kernel()->set_args(args.data(), args.size()); | |||
| task_def.mutable_kernel()->set_args_size(100); | |||
| task_def.mutable_kernel()->mutable_context()->set_kernel_type( | |||
| ccKernelType::TE); | |||
| uint16_t args_offset = 20; | |||
| char *a = reinterpret_cast<char *>(&args_offset); | |||
| task_def.mutable_kernel()->mutable_context()->set_args_offset( | |||
| a, 2 * sizeof(uint16_t)); | |||
| EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS); | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
| NodePtr node = graph->AddNode(op_desc); | |||
| std::unique_ptr<NodeItem> new_node; | |||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||
| NodeItem *node_item = new_node.get(); | |||
| node_item->input_start = 0; | |||
| node_item->output_start = 0; | |||
| node_item->is_dynamic = true; | |||
| node_item->shape_inference_type = DEPEND_SHAPE_RANGE; | |||
| GraphItem graph_item; | |||
| graph_item.node_items_.emplace_back(node_item); | |||
| graph_item.total_inputs_ = 2; | |||
| graph_item.total_outputs_ = 1; | |||
| GraphExecutionContext graph_context; | |||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||
| graph_context.callback_manager = | |||
| std::unique_ptr<CallbackManager>(new CallbackManager()); | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto outputs_shape = | |||
| reinterpret_cast<uint32_t(*)[1]>(task1->shape_buffer_->GetData()); | |||
| outputs_shape[0][0] = 2; | |||
| outputs_shape[0][1] = 1; | |||
| outputs_shape[0][2] = 2; | |||
| ASSERT_EQ(task1->UpdateOutputsShape(*node_state->GetTaskContext()), SUCCESS); | |||
| dlog_setlevel(0, 3, 0); | |||
| } | |||
| } // namespace ge | |||