Browse Source

Remove gentask in DEPEND_COMPUTE task executor.

tags/v1.2.0
unknown 3 years ago
parent
commit
33945b054b
2 changed files with 6 additions and 9 deletions
  1. +4
    -3
      ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc
  2. +2
    -6
      ge/hybrid/node_executor/aicpu/aicpu_node_executor.h

+ 4
- 3
ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc View File

@@ -352,6 +352,10 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) {
need_sync_ = true; need_sync_ = true;
} }
GELOGI("Node[%s] init end.", node_name_.c_str()); GELOGI("Node[%s] init end.", node_name_.c_str());
auto task_defs = model.GetTaskDefs(node_item_->node);
if (unknown_type_ == DEPEND_COMPUTE) {
GE_CHK_STATUS_RET_NOLOG(SetMemCopyTask((*task_defs)[1]));
}
return SUCCESS; return SUCCESS;
} }


@@ -829,9 +833,6 @@ Status AiCpuNodeExecutor::LoadTask(const HybridModel &model,
"Load task for node %s failed.", node->GetName().c_str()); "Load task for node %s failed.", node->GetName().c_str());


GE_CHK_STATUS_RET(aicpu_task->Init(model), "Node[%s] task init failed.", node->GetName().c_str()); GE_CHK_STATUS_RET(aicpu_task->Init(model), "Node[%s] task init failed.", node->GetName().c_str());
if (node_item->shape_inference_type == DEPEND_COMPUTE) {
GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask((*task_defs)[1]));
}


task = std::move(aicpu_task); task = std::move(aicpu_task);
GELOGD("Node[%s] load task end.", node->GetName().c_str()); GELOGD("Node[%s] load task end.", node->GetName().c_str());


+ 2
- 6
ge/hybrid/node_executor/aicpu/aicpu_node_executor.h View File

@@ -42,10 +42,6 @@ class AicpuNodeTaskBase : public NodeTask {


virtual Status Init(const HybridModel &model) = 0; virtual Status Init(const HybridModel &model) = 0;


virtual Status SetMemCopyTask(const domi::TaskDef &task_def) {
return UNSUPPORTED;
}

Status UpdateArgs(TaskContext &context) override; Status UpdateArgs(TaskContext &context) override;


Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
@@ -94,8 +90,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase {


Status Init(const HybridModel &model) override; Status Init(const HybridModel &model) override;


Status SetMemCopyTask(const domi::TaskDef &task_def) override;

protected: protected:


Status LaunchTask(TaskContext &context) override; Status LaunchTask(TaskContext &context) override;
@@ -105,6 +99,8 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase {
Status UpdateIoAddr(TaskContext &context) override; Status UpdateIoAddr(TaskContext &context) override;


private: private:
Status SetMemCopyTask(const domi::TaskDef &task_def);

Status InitForDependComputeTask(); Status InitForDependComputeTask();


Status UpdateShapeAndDataByResultSummary(TaskContext &context); Status UpdateShapeAndDataByResultSummary(TaskContext &context);


Loading…
Cancel
Save