Browse Source

UpdateTiling pre-place

tags/v1.2.0
wxl 3 years ago
parent
commit
bc1f6ca510
5 changed files with 39 additions and 6 deletions
  1. +8
    -0
      ge/hybrid/executor/node_state.cc
  2. +5
    -0
      ge/hybrid/executor/node_state.h
  3. +25
    -4
      ge/hybrid/executor/subgraph_executor.cc
  4. +1
    -1
      ge/hybrid/executor/subgraph_executor.h
  5. +0
    -1
      ge/hybrid/node_executor/node_executor.cc

+ 8
- 0
ge/hybrid/executor/node_state.cc View File

@@ -188,6 +188,14 @@ Status NodeState::WaitForPrepareDone() {
return SUCCESS; return SUCCESS;
} }


void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) {
task_context_ = task_context;
}

std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_;
}

Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) {
GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str());
HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled"); HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled");


+ 5
- 0
ge/hybrid/executor/node_state.h View File

@@ -29,6 +29,7 @@ namespace hybrid {
class NodeTask; class NodeTask;
struct GraphExecutionContext; struct GraphExecutionContext;
class SubgraphContext; class SubgraphContext;
class TaskContext;


class ShapeFuture { class ShapeFuture {
public: public:
@@ -103,6 +104,9 @@ struct NodeState {


Status AwaitInputTensors(GraphExecutionContext &context) const; Status AwaitInputTensors(GraphExecutionContext &context) const;


void SetTaskContext(std::shared_ptr<TaskContext> &task_context);
std::shared_ptr<TaskContext> GetTaskContext();

private: private:
const NodeItem *node_item_ = nullptr; const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr; std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -110,6 +114,7 @@ struct NodeState {
OpDescPtr op_desc_; OpDescPtr op_desc_;
ShapeInferenceState shape_inference_state_; ShapeInferenceState shape_inference_state_;
SubgraphContext *subgraph_context_; SubgraphContext *subgraph_context_;
std::shared_ptr<TaskContext> task_context_ = nullptr;
std::mutex mu_; std::mutex mu_;
}; };




+ 25
- 4
ge/hybrid/executor/subgraph_executor.cc View File

@@ -232,6 +232,15 @@ Status SubgraphExecutor::PrepareNodes() {
node_state->SetKernelTask(node_item.kernel_task); node_state->SetKernelTask(node_item.kernel_task);
} }
} }
auto unique_task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state->GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state->GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContex(shared_task_context);
} }


if (!ready_queue_.Push(p_node_state)) { if (!ready_queue_.Push(p_node_state)) {
@@ -267,6 +276,19 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
} else { } else {
node_state.SetKernelTask(node_item.kernel_task); node_state.SetKernelTask(node_item.kernel_task);
} }
auto unique_task_context = TaskContext::Create(*node_state.GetNodeItem(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state.GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state.SetTaskContex(shared_task_context);
GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
RECORD_COMPILE_EVENT(ctx, node_item.NodeItem().c_str(), "[UpdateTilingData] start");
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws
RECORD_COMPILE_EVENT(ctx, node_item.NodeItem().c_str(), "[UpdateTilingData] end");
return SUCCESS; return SUCCESS;
} }


@@ -295,10 +317,9 @@ Status SubgraphExecutor::LaunchTasks() {
GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone()); GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone());


GELOGD("[%s] Start to execute.", node_state->GetName().c_str()); GELOGD("[%s] Start to execute.", node_state->GetName().c_str());
auto task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(task_context);
task_context->SetForceInferShape(force_infer_shape_);
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
auto shared_task_context = node_state->GetTaskContext();
GE_CHECK_NOTNULL(shared_task_context);
shared_task_context->SetForceInferShape(force_infer_shape_);
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_), HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_),
"[%s] Execute node failed.", "[%s] Execute node failed.",
node_state->GetName().c_str()); node_state->GetName().c_str());


+ 1
- 1
ge/hybrid/executor/subgraph_executor.h View File

@@ -75,7 +75,7 @@ class SubgraphExecutor {
Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc); Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc);


private: private:
static Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state);
Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state);
static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state); static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state);
Status Init(const std::vector<TensorValue> &inputs, Status Init(const std::vector<TensorValue> &inputs,
const std::vector<ConstGeTensorDescPtr> &input_desc); const std::vector<ConstGeTensorDescPtr> &input_desc);


+ 0
- 1
ge/hybrid/node_executor/node_executor.cc View File

@@ -38,7 +38,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE";
} }
Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
GE_CHK_STATUS_RET_NOLOG(task.UpdateTilingData(context)); // update op_desc before alloc ws
GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces());
GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context));
return SUCCESS; return SUCCESS;


Loading…
Cancel
Save