diff --git a/ge/single_op/single_op.h b/ge/single_op/single_op.h index 94d7227b..7e05dd5f 100755 --- a/ge/single_op/single_op.h +++ b/ge/single_op/single_op.h @@ -92,6 +92,7 @@ class DynamicSingleOp { rtStream_t stream_ = nullptr; size_t num_inputs_ = 0; size_t num_outputs_ = 0; + ComputeGraphPtr compute_graph_; }; } // namespace ge #endif // GE_SINGLE_OP_SINGLE_OP_H_ diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index e5d15beb..7f42f03c 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -529,44 +529,14 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { return BuildTaskList(&resource, single_op); } -Status SingleOpModel::BuildModelTaskKernel(StreamResource *stream_resource, const TaskDef &task_def, - DynamicSingleOp &single_op) { - auto task_type = static_cast(task_def.type()); - const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : - task_def.kernel_with_handle().context(); +Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) { + auto ge_model = model_helper_.GetGeModel(); + GE_CHECK_NOTNULL(ge_model); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == ccKernelType::TE) { - GELOGD("Building TBE task."); - TbeOpTask *tbe_task = nullptr; - GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task)); - tbe_task->SetModelArgs(model_name_, model_id_); - if (tbe_task->tiling_buffer_ != nullptr) { - GELOGD("tiling buffer is not nullptr."); - tbe_task->stream_resource_ = stream_resource; - } - single_op.op_task_.reset(tbe_task); - } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { - GELOGD("Building AICPU_CC task"); - OpTask *task = nullptr; - uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; - GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id); - GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id)); - task->SetModelArgs(model_name_, model_id_); - single_op.op_task_.reset(task); - } else { - GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, - "[Check][Param:TaskDef]Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", - context.kernel_type()); - REPORT_INNER_ERROR("E19999", - "BuildModelTaskKernel fail for got:%u not supported, Only TBE, AI_CPU, CUST_AI_CPU kernel are supported.", - context.kernel_type()); - return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; - } - return SUCCESS; -} + auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); + GE_CHECK_NOTNULL(compute_graph); + single_op.compute_graph_ = compute_graph; -Status SingleOpModel::BuildTaskListForDynamicOp(StreamResource *stream_resource, DynamicSingleOp &single_op) { if (tbe_tasks_.size() > 0) { const auto &task_def = tbe_tasks_[0]; GELOGD("Building TBE task."); diff --git a/ge/single_op/single_op_model.h b/ge/single_op/single_op_model.h index b5198e3d..45616d9a 100755 --- a/ge/single_op/single_op_model.h +++ b/ge/single_op/single_op_model.h @@ -71,8 +71,6 @@ class SingleOpModel { Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, uint64_t kernel_id); Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id); - Status BuildModelTaskKernel(StreamResource *stream_resource, const domi::TaskDef &task_def, - DynamicSingleOp &single_op); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(OpTask *task, SingleOp &op); diff --git a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc index a6f5c2de..1d5bbb3d 100644 --- a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc @@ -87,6 +87,7 @@ TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) { TEST_F(NodeExecutorTest, TestInitAndFinalize) { auto &manager = NodeExecutorManager::GetInstance(); manager.FinalizeExecutors(); + manager.FinalizeExecutors(); manager.EnsureInitialized(); manager.EnsureInitialized(); const NodeExecutor *executor = nullptr; @@ -97,7 +98,7 @@ TEST_F(NodeExecutorTest, TestInitAndFinalize) { manager.FinalizeExecutors(); ASSERT_FALSE(manager.executors_.empty()); manager.FinalizeExecutors(); - // ASSERT_TRUE(manager.executors_.empty()); - // ASSERT_TRUE(finalized); + ASSERT_TRUE(manager.executors_.empty()); + ASSERT_TRUE(finalized); } } // namespace ge