From: @wan_xuelei
Reviewed-by: @xchu42, @ji_chen
Signed-off-by: @ji_chen
@@ -188,6 +188,14 @@ Status NodeState::WaitForPrepareDone() {
   return SUCCESS;
 }
 
+void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) {
+  task_context_ = task_context;
+}
+
+std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
+  return task_context_;
+}
+
 Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) {
   GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str());
   HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled");
@@ -29,6 +29,7 @@ namespace hybrid {
 class NodeTask;
 struct GraphExecutionContext;
 class SubgraphContext;
+class TaskContext;
 
 class ShapeFuture {
  public:
@@ -103,6 +104,9 @@ struct NodeState {
   Status AwaitInputTensors(GraphExecutionContext &context) const;
 
+  void SetTaskContext(std::shared_ptr<TaskContext> &task_context);
+  std::shared_ptr<TaskContext> GetTaskContext();
+
  private:
   const NodeItem *node_item_ = nullptr;
   std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -110,6 +114,7 @@ struct NodeState {
   OpDescPtr op_desc_;
   ShapeInferenceState shape_inference_state_;
   SubgraphContext *subgraph_context_;
+  std::shared_ptr<TaskContext> task_context_ = nullptr;
   std::mutex mu_;
 };
@@ -231,6 +231,15 @@ Status SubgraphExecutor::PrepareNodes() {
       } else {
         node_state->SetKernelTask(node_item.kernel_task);
       }
+      auto unique_task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get());
+      GE_CHECK_NOTNULL(unique_task_context);
+      const auto &task = node_state->GetKernelTask();
+      if (task == nullptr) {
+        GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state->GetName().c_str());
+        return INTERNAL_ERROR;
+      }
+      auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
+      node_state->SetTaskContext(shared_task_context);
     }
   }
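A note on the ownership handoff in the hunk above: TaskContext::Create returns a unique_ptr, but NodeState has to keep the context alive until the node is launched, so the diff releases the unique_ptr into a shared_ptr before caching it. A minimal, self-contained C++ sketch of that handoff (Widget and MakeWidget are made-up stand-ins, not GE types):

#include <iostream>
#include <memory>

// Stand-in for a factory that returns exclusive ownership, like TaskContext::Create.
struct Widget {
  int id = 0;
};

std::unique_ptr<Widget> MakeWidget(int id) {
  auto widget = std::unique_ptr<Widget>(new Widget());
  widget->id = id;
  return widget;
}

int main() {
  auto unique_widget = MakeWidget(42);
  if (unique_widget == nullptr) {  // plays the role of GE_CHECK_NOTNULL above
    return 1;
  }
  // Transfer ownership into a shared_ptr so another object can cache it.
  // release() + construct matches the diff; moving the unique_ptr into the
  // shared_ptr constructor is the equivalent idiom.
  auto shared_widget = std::shared_ptr<Widget>(unique_widget.release());
  std::cout << "id=" << shared_widget->id
            << " use_count=" << shared_widget.use_count() << std::endl;
  return 0;
}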
@@ -267,6 +276,19 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state) {
   } else {
     node_state.SetKernelTask(node_item.kernel_task);
   }
+  auto unique_task_context = TaskContext::Create(*node_state.GetNodeItem(), context_, subgraph_context_.get());
+  GE_CHECK_NOTNULL(unique_task_context);
+  const auto &task = node_state.GetKernelTask();
+  if (task == nullptr) {
+    GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str());
+    return INTERNAL_ERROR;
+  }
+  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
+  node_state.SetTaskContext(shared_task_context);
+  GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
+  RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
+  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context));  // update op_desc before alloc ws
+  RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
   return SUCCESS;
 }
@@ -295,10 +317,9 @@ Status SubgraphExecutor::LaunchTasks() {
     GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone());
 
     GELOGD("[%s] Start to execute.", node_state->GetName().c_str());
-    auto task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get());
-    GE_CHECK_NOTNULL(task_context);
-    task_context->SetForceInferShape(force_infer_shape_);
-    auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
+    auto shared_task_context = node_state->GetTaskContext();
+    GE_CHECK_NOTNULL(shared_task_context);
+    shared_task_context->SetForceInferShape(force_infer_shape_);
     HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_),
                           "[%s] Execute node failed.",
                           node_state->GetName().c_str());
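Taken together with the prepare-side hunks, this change moves TaskContext creation out of the launch loop: the context is built once while the node is prepared, cached on its NodeState, and LaunchTasks only looks it up. A simplified, self-contained model of that prepare-once / reuse-at-launch flow (Ctx and Node below are illustrative stand-ins, not the actual hybrid executor types):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Ctx {  // stand-in for TaskContext
  std::string node_name;
  bool force_infer_shape = false;
};

struct Node {  // stand-in for NodeState
  std::string name;
  std::shared_ptr<Ctx> ctx;  // cached at prepare time, like task_context_
};

// Prepare phase: build the per-node context once (the PR also refreshes
// tiling data here, before workspaces are allocated).
void PrepareNode(Node &node) {
  node.ctx = std::make_shared<Ctx>();
  node.ctx->node_name = node.name;
}

// Launch phase: reuse the cached context instead of rebuilding it.
bool LaunchNode(Node &node, bool force_infer_shape) {
  auto ctx = node.ctx;  // was: create a fresh context on every launch
  if (ctx == nullptr) {
    return false;  // same role as GE_CHECK_NOTNULL(shared_task_context)
  }
  ctx->force_infer_shape = force_infer_shape;
  std::cout << "executing " << ctx->node_name << std::endl;
  return true;
}

int main() {
  std::vector<Node> nodes{{"node_a", nullptr}, {"node_b", nullptr}};
  for (auto &node : nodes) {
    PrepareNode(node);
  }
  for (auto &node : nodes) {
    if (!LaunchNode(node, false)) {
      return 1;
    }
  }
  return 0;
}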
@@ -75,7 +75,7 @@ class SubgraphExecutor {
   Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc);
 
  private:
-  static Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state);
+  Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state);
   static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state);
   Status Init(const std::vector<TensorValue> &inputs,
               const std::vector<ConstGeTensorDescPtr> &input_desc);
@@ -38,7 +38,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE";
 }
 
 Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
   GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
-  GE_CHK_STATUS_RET_NOLOG(task.UpdateTilingData(context));  // update op_desc before alloc ws
   GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces());
   GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context));
   return SUCCESS;
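This last hunk is the counterpart of the PrepareForExecution change: tiling data is now refreshed while the node is being prepared (under the correct rt context), so PrepareTask at launch time only allocates outputs and workspaces and refreshes the task args. A hedged sketch of the resulting call order, using placeholder functions rather than the real GE calls:

#include <iostream>

// Placeholders only; each prints the step instead of doing real work.
void UpdateTilingData()   { std::cout << "update tiling data" << std::endl; }
void AllocateOutputs()    { std::cout << "allocate outputs" << std::endl; }
void AllocateWorkspaces() { std::cout << "allocate workspaces" << std::endl; }
void UpdateArgs()         { std::cout << "update args" << std::endl; }

// Prepare phase (roughly SubgraphExecutor::PrepareForExecution after this PR):
// the tiling update runs here, before any workspace is allocated.
void PreparePhase() {
  UpdateTilingData();
}

// Launch phase (roughly NodeExecutor::PrepareTask after this PR): the tiling
// step is no longer repeated here.
void LaunchPhase() {
  AllocateOutputs();
  AllocateWorkspaces();
  UpdateArgs();
}

int main() {
  PreparePhase();
  LaunchPhase();
  return 0;
}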