diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index c6fb76ed..fb0f2d69 100755 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -352,6 +352,10 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) { need_sync_ = true; } GELOGI("Node[%s] init end.", node_name_.c_str()); + auto task_defs = model.GetTaskDefs(node_item_->node); + if (unknown_type_ == DEPEND_COMPUTE) { + GE_CHK_STATUS_RET_NOLOG(SetMemCopyTask((*task_defs)[1])); + } return SUCCESS; } @@ -829,9 +833,6 @@ Status AiCpuNodeExecutor::LoadTask(const HybridModel &model, "Load task for node %s failed.", node->GetName().c_str()); GE_CHK_STATUS_RET(aicpu_task->Init(model), "Node[%s] task init failed.", node->GetName().c_str()); - if (node_item->shape_inference_type == DEPEND_COMPUTE) { - GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask((*task_defs)[1])); - } task = std::move(aicpu_task); GELOGD("Node[%s] load task end.", node->GetName().c_str()); diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index c6e63ee0..0a21c6ef 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -42,10 +42,6 @@ class AicpuNodeTaskBase : public NodeTask { virtual Status Init(const HybridModel &model) = 0; - virtual Status SetMemCopyTask(const domi::TaskDef &task_def) { - return UNSUPPORTED; - } - Status UpdateArgs(TaskContext &context) override; Status ExecuteAsync(TaskContext &context, std::function done_callback) override; @@ -94,8 +90,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase { Status Init(const HybridModel &model) override; - Status SetMemCopyTask(const domi::TaskDef &task_def) override; - protected: Status LaunchTask(TaskContext &context) override; @@ -105,6 +99,8 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase { Status UpdateIoAddr(TaskContext &context) override; private: + Status SetMemCopyTask(const domi::TaskDef &task_def); + Status InitForDependComputeTask(); Status UpdateShapeAndDataByResultSummary(TaskContext &context);