diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 171ddaf3..00921705 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -188,6 +188,14 @@ Status NodeState::WaitForPrepareDone() { return SUCCESS; } +void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) { + task_context_ = task_context; +} + +std::shared_ptr<TaskContext> NodeState::GetTaskContext() { + return task_context_; +} + Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled"); diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 02a362b4..c68a19ac 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -29,6 +29,7 @@ namespace hybrid { class NodeTask; struct GraphExecutionContext; class SubgraphContext; +class TaskContext; class ShapeFuture { public: @@ -103,6 +104,9 @@ struct NodeState { Status AwaitInputTensors(GraphExecutionContext &context) const; + void SetTaskContext(std::shared_ptr<TaskContext> &task_context); + std::shared_ptr<TaskContext> GetTaskContext(); + private: const NodeItem *node_item_ = nullptr; std::shared_ptr<NodeTask> kernel_task_ = nullptr; @@ -110,6 +114,7 @@ struct NodeState { OpDescPtr op_desc_; ShapeInferenceState shape_inference_state_; SubgraphContext *subgraph_context_; + std::shared_ptr<TaskContext> task_context_ = nullptr; std::mutex mu_; }; diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index f7b063c7..8f7334de 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -232,6 +232,15 @@ Status SubgraphExecutor::PrepareNodes() { node_state->SetKernelTask(node_item.kernel_task); } } + auto unique_task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get()); + GE_CHECK_NOTNULL(unique_task_context); + const auto &task 
= node_state->GetKernelTask(); + if (task == nullptr) { + GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state->GetName().c_str()); + return INTERNAL_ERROR; + } + auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); + node_state->SetTaskContext(shared_task_context); } if (!ready_queue_.Push(p_node_state)) { @@ -267,6 +276,19 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta } else { node_state.SetKernelTask(node_item.kernel_task); } + auto unique_task_context = TaskContext::Create(*node_state.GetNodeItem(), context_, subgraph_context_.get()); + GE_CHECK_NOTNULL(unique_task_context); + const auto &task = node_state.GetKernelTask(); + if (task == nullptr) { + GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str()); + return INTERNAL_ERROR; + } + auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); + node_state.SetTaskContext(shared_task_context); + GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); + RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); + GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws + RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); return SUCCESS; } @@ -295,10 +317,9 @@ Status SubgraphExecutor::LaunchTasks() { GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone()); GELOGD("[%s] Start to execute.", node_state->GetName().c_str()); - auto task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(task_context); - task_context->SetForceInferShape(force_infer_shape_); - auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); + auto shared_task_context = node_state->GetTaskContext(); + GE_CHECK_NOTNULL(shared_task_context); + shared_task_context->SetForceInferShape(force_infer_shape_); HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, 
shared_task_context, *context_), "[%s] Execute node failed.", node_state->GetName().c_str()); diff --git a/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h index d1949947..4523e2c4 100644 --- a/ge/hybrid/executor/subgraph_executor.h +++ b/ge/hybrid/executor/subgraph_executor.h @@ -75,7 +75,7 @@ class SubgraphExecutor { Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc); private: - static Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); + Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state); Status Init(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc); diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index 02427b91..12e98160 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -38,7 +38,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; } Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); - GE_CHK_STATUS_RET_NOLOG(task.UpdateTilingData(context)); // update op_desc before alloc ws GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS;