From: @isaacxr Reviewed-by: @sheng-nan,@xchu42 Signed-off-by: @ji_chentags/v1.2.0
| @@ -351,6 +351,7 @@ set(TRAIN_SRC_LIST | |||||
| "hybrid/executor/node_done_manager.cc" | "hybrid/executor/node_done_manager.cc" | ||||
| "hybrid/executor/hybrid_profiler.cc" | "hybrid/executor/hybrid_profiler.cc" | ||||
| "hybrid/executor/hybrid_model_executor.cc" | "hybrid/executor/hybrid_model_executor.cc" | ||||
| "hybrid/executor/hybrid_model_pipeline_executor.cc" | |||||
| "hybrid/executor/hybrid_model_async_executor.cc" | "hybrid/executor/hybrid_model_async_executor.cc" | ||||
| "hybrid/executor/hybrid_execution_context.cc" | "hybrid/executor/hybrid_execution_context.cc" | ||||
| "hybrid/executor/subgraph_context.cc" | "hybrid/executor/subgraph_context.cc" | ||||
| @@ -81,6 +81,7 @@ set(SRC_LIST | |||||
| "../hybrid/executor/node_done_manager.cc" | "../hybrid/executor/node_done_manager.cc" | ||||
| "../hybrid/executor/hybrid_profiler.cc" | "../hybrid/executor/hybrid_profiler.cc" | ||||
| "../hybrid/executor/hybrid_model_executor.cc" | "../hybrid/executor/hybrid_model_executor.cc" | ||||
| "../hybrid/executor/hybrid_model_pipeline_executor.cc" | |||||
| "../hybrid/executor/hybrid_model_async_executor.cc" | "../hybrid/executor/hybrid_model_async_executor.cc" | ||||
| "../hybrid/executor/hybrid_execution_context.cc" | "../hybrid/executor/hybrid_execution_context.cc" | ||||
| "../hybrid/executor/subgraph_context.cc" | "../hybrid/executor/subgraph_context.cc" | ||||
| @@ -3032,6 +3032,7 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); | GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); | ||||
| GE_DUMP(compute_graph, "AfterDynamicShapePartition"); | |||||
| GE_TIMESTAMP_START(GraphPartition); | GE_TIMESTAMP_START(GraphPartition); | ||||
| GraphPartitioner &partitioner = GetCompilerStages(graph_node->GetGraphId()).partitioner; | GraphPartitioner &partitioner = GetCompilerStages(graph_node->GetGraphId()).partitioner; | ||||
| ret = partitioner.Partition(compute_graph, GraphPartitioner::kPartitioning); | ret = partitioner.Partition(compute_graph, GraphPartitioner::kPartitioning); | ||||
| @@ -742,6 +742,12 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { | |||||
| if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| bool identity_reserved = false; | |||||
| AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved); | |||||
| if (identity_reserved) { | |||||
| GELOGD("Identity [%s] need to be reserved", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) { | if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) { | ||||
| // split identity | // split identity | ||||
| ret = SplitIdentity(node); | ret = SplitIdentity(node); | ||||
| @@ -52,6 +52,7 @@ Status StagePartitioner::Partition() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GE_DUMP(root_graph_, "BeforeStagePartition"); | |||||
| if (SplitStageLevel() != SUCCESS) { | if (SplitStageLevel() != SUCCESS) { | ||||
| GELOGE(FAILED, "Split graph-stage for graph %s failed.", root_graph_->GetName().c_str()); | GELOGE(FAILED, "Split graph-stage for graph %s failed.", root_graph_->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -74,6 +75,7 @@ Status StagePartitioner::Partition() { | |||||
| "maybe stage_level was not set correctly.", root_graph_->GetName().c_str()); | "maybe stage_level was not set correctly.", root_graph_->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_DUMP(root_graph_, "AfterStagePartition"); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -460,6 +460,7 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat | |||||
| .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) | .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) | ||||
| .Build(); | .Build(); | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true); | |||||
| if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { | if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "Insert IDENTITY node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); | GELOGE(FAILED, "Insert IDENTITY node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "hybrid_execution_context.h" | #include "hybrid_execution_context.h" | ||||
| #include <atomic> | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -23,7 +24,14 @@ const uint32_t kEndOfSequence = 0x0704000a; | |||||
| const uint32_t kEndOfSequenceNew = 507005; | const uint32_t kEndOfSequenceNew = 507005; | ||||
| const int32_t kModelAbortNormal = 0x0704000e; | const int32_t kModelAbortNormal = 0x0704000e; | ||||
| const int32_t kModelAbortNormalNew = 507024; | const int32_t kModelAbortNormalNew = 507024; | ||||
| std::atomic_ulong context_id_gen {}; | |||||
| } // namespace | } // namespace | ||||
| GraphExecutionContext::GraphExecutionContext() { | |||||
| context_id = context_id_gen++; | |||||
| } | |||||
| void GraphExecutionContext::SetErrorCode(Status error_code) { | void GraphExecutionContext::SetErrorCode(Status error_code) { | ||||
| std::lock_guard<std::mutex> lk(mu); | std::lock_guard<std::mutex> lk(mu); | ||||
| this->status = error_code; | this->status = error_code; | ||||
| @@ -48,11 +48,15 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| struct GraphExecutionContext { | struct GraphExecutionContext { | ||||
| GraphExecutionContext(); | |||||
| ~GraphExecutionContext() = default; | |||||
| void SetErrorCode(Status error_code); | void SetErrorCode(Status error_code); | ||||
| Status GetStatus() const; | Status GetStatus() const; | ||||
| Status Synchronize(rtStream_t rt_stream); | Status Synchronize(rtStream_t rt_stream); | ||||
| uint64_t session_id = 0; | uint64_t session_id = 0; | ||||
| uint64_t context_id = 0; | |||||
| const HybridModel *model = nullptr; | const HybridModel *model = nullptr; | ||||
| const GEThreadLocalContext *ge_context = nullptr; | const GEThreadLocalContext *ge_context = nullptr; | ||||
| rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
| @@ -67,6 +71,8 @@ struct GraphExecutionContext { | |||||
| std::atomic_bool is_eos_; | std::atomic_bool is_eos_; | ||||
| long profiling_level = 0; | long profiling_level = 0; | ||||
| long iteration = 0; | long iteration = 0; | ||||
| private: | |||||
| Status status = SUCCESS; | Status status = SUCCESS; | ||||
| mutable std::mutex mu; | mutable std::mutex mu; | ||||
| }; | }; | ||||
| @@ -75,7 +81,8 @@ struct GraphExecutionContext { | |||||
| do { \ | do { \ | ||||
| if ((context != nullptr) && (context)->profiler != nullptr) { \ | if ((context != nullptr) && (context)->profiler != nullptr) { \ | ||||
| if (node_name != nullptr) { \ | if (node_name != nullptr) { \ | ||||
| context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GeLog::GetTid(), node_name, category, \ | |||||
| context->profiler->RecordEvent(evt_type, "tid:%lu [%s@%ld] [%s] " fmt, \ | |||||
| GeLog::GetTid(), node_name, context->iteration, category, \ | |||||
| ##__VA_ARGS__); \ | ##__VA_ARGS__); \ | ||||
| } else { \ | } else { \ | ||||
| context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \ | context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \ | ||||
| @@ -25,6 +25,7 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| const int kDataOutputIndex = 0; | const int kDataOutputIndex = 0; | ||||
| const size_t kMinimumPiplineStages = 2; | |||||
| } | } | ||||
| HybridModelAsyncExecutor::HybridModelAsyncExecutor(HybridModel *model) | HybridModelAsyncExecutor::HybridModelAsyncExecutor(HybridModel *model) | ||||
| : model_(model), run_flag_(false) { | : model_(model), run_flag_(false) { | ||||
| @@ -95,7 +96,17 @@ Status HybridModelAsyncExecutor::Init() { | |||||
| executor_ = std::unique_ptr<HybridModelExecutor>(new(std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); | executor_ = std::unique_ptr<HybridModelExecutor>(new(std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); | ||||
| GE_CHECK_NOTNULL(executor_); | GE_CHECK_NOTNULL(executor_); | ||||
| GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine"); | GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine"); | ||||
| GELOGI("HybridModel stage nums:%zu", model_->GetRootGraphItem()->NumGroups()); | |||||
| if (model_->GetRootGraphItem()->NumGroups() >= kMinimumPiplineStages) { | |||||
| pipe_executor_ = | |||||
| std::unique_ptr<HybridModelPipelineExecutor>(new(std::nothrow) HybridModelPipelineExecutor(model_, device_id_)); | |||||
| GE_CHECK_NOTNULL(pipe_executor_); | |||||
| GE_CHK_STATUS_RET(pipe_executor_->Init(), "Failed to init hybrid engine"); | |||||
| } | |||||
| GE_CHK_STATUS_RET(InitInputDesc(), "Failed to init input tensors"); | GE_CHK_STATUS_RET(InitInputDesc(), "Failed to init input tensors"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -135,7 +146,18 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||||
| CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); | CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); | ||||
| continue, "PreRun failed."); // [No need to check value] | continue, "PreRun failed."); // [No need to check value] | ||||
| ret = executor_->Execute(args); | |||||
| if (pipe_executor_ != nullptr) { | |||||
| GELOGI("HybridModel will execute in pipeline mode"); | |||||
| auto iter_per_run = std::getenv("ITER_NUM"); | |||||
| if (iter_per_run) { | |||||
| args.num_loops = static_cast<int>(strtol(iter_per_run, nullptr, 10)); | |||||
| } | |||||
| ret = pipe_executor_->Execute(args); | |||||
| } else { | |||||
| GELOGI("HybridModel will execute in singleline mode"); | |||||
| ge::GetContext().SetSessionId(executor_->GetContext()->session_id); | |||||
| ret = executor_->Execute(args); | |||||
| } | |||||
| ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); | ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "graph/load/model_manager/data_inputer.h" | #include "graph/load/model_manager/data_inputer.h" | ||||
| #include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
| #include "hybrid/executor/hybrid_model_pipeline_executor.h" | |||||
| #include "runtime/stream.h" | #include "runtime/stream.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -81,6 +82,7 @@ class HybridModelAsyncExecutor { | |||||
| std::atomic_bool run_flag_; | std::atomic_bool run_flag_; | ||||
| std::unique_ptr<DataInputer> data_inputer_; | std::unique_ptr<DataInputer> data_inputer_; | ||||
| std::unique_ptr<HybridModelExecutor> executor_; | std::unique_ptr<HybridModelExecutor> executor_; | ||||
| std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_; | |||||
| std::future<Status> future_; | std::future<Status> future_; | ||||
| uint64_t iterator_count_ = 0; | uint64_t iterator_count_ = 0; | ||||
| @@ -81,13 +81,14 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| args.outputs.clear(); | args.outputs.clear(); | ||||
| HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | ||||
| context_.iteration +=1; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelExecutor::Cleanup() { | Status HybridModelExecutor::Cleanup() { | ||||
| GELOGD("Start to cleanup."); | GELOGD("Start to cleanup."); | ||||
| context_.callback_manager->Destroy(); | context_.callback_manager->Destroy(); | ||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.session_id)); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | |||||
| GELOGD("Cleanup successfully."); | GELOGD("Cleanup successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -105,7 +106,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | ||||
| context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | ||||
| GE_CHECK_NOTNULL(context_.allocator); | GE_CHECK_NOTNULL(context_.allocator); | ||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new(std::nothrow)CallbackManager(stream_)); | |||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new(std::nothrow)CallbackManager()); | |||||
| GE_CHECK_NOTNULL(context_.callback_manager); | GE_CHECK_NOTNULL(context_.callback_manager); | ||||
| context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); | context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); | ||||
| const char *profiling_level = std::getenv(kEnvProfilingLevel); | const char *profiling_level = std::getenv(kEnvProfilingLevel); | ||||
| @@ -126,7 +127,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { | Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { | ||||
| GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); | GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); | ||||
| string ctx_id = std::to_string(context.session_id); | |||||
| string ctx_id = std::to_string(context.context_id); | |||||
| RuntimeInferenceContext::DestroyContext(ctx_id); | RuntimeInferenceContext::DestroyContext(ctx_id); | ||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -32,6 +32,7 @@ class HybridModelExecutor { | |||||
| std::vector<TensorValue> outputs; | std::vector<TensorValue> outputs; | ||||
| std::vector<ConstGeTensorDescPtr> output_desc; | std::vector<ConstGeTensorDescPtr> output_desc; | ||||
| bool is_eos = false; | bool is_eos = false; | ||||
| int num_loops = 10; | |||||
| }; | }; | ||||
| HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); | HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); | ||||
| @@ -0,0 +1,284 @@ | |||||
| #include "hybrid_model_pipeline_executor.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "graph/ge_context.h" | |||||
| #include "graph/runtime_inference_context.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| namespace { | |||||
| constexpr int kNumExecutors = 2; | |||||
| const int kIntBase = 10; | |||||
| const char *const kEnvProfilingLevel = "HYBRID_PROFILING_LEVEL"; | |||||
| } | |||||
| StageExecutor::StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config) | |||||
| : id_(id), model_(model), pipe_config_(config) {} | |||||
| StageExecutor::~StageExecutor() { GELOGD("~StageExecutor(), id = %d", id_); } | |||||
| Status StageExecutor::Init() { | |||||
| GELOGD("[Executor: %d] Start to init StateExecutor", id_); | |||||
| context_.rt_context = pipe_config_->rt_context; | |||||
| GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | |||||
| GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); | |||||
| context_.stream = stream_; | |||||
| root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
| GE_CHECK_NOTNULL(root_graph_executor_); | |||||
| GELOGD("[Executor: %d] Init stage executor successfully", id_); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::ResetExecutionContext(GraphExecutionContext &context) { | |||||
| GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); | |||||
| string ctx_id = std::to_string(context.context_id); | |||||
| RuntimeInferenceContext::DestroyContext(ctx_id); | |||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc, | |||||
| int iteration_count) { | |||||
| GELOGD("Start"); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | |||||
| int num_loops = iteration_count / pipe_config_->num_executors; | |||||
| if (id_ < iteration_count % iteration_count) { | |||||
| num_loops += 1; | |||||
| } | |||||
| FMK_INT32_MULCHECK(num_loops, pipe_config_->num_stages); | |||||
| num_loops *= pipe_config_->num_stages; | |||||
| GELOGD("[Executor: %d] loop count = %d", id_, num_loops); | |||||
| for (int loop_idx = 0; loop_idx < num_loops; ++loop_idx) { | |||||
| GELOGD("[Executor: %d] Start to wait for task.", id_); | |||||
| StageTask task_info; | |||||
| task_queue_.Pop(task_info); | |||||
| GELOGD("[Executor: %d] Got task, stage = %d, iteration = %ld", id_, task_info.stage, task_info.iteration); | |||||
| if (task_info.iteration >= pipe_config_->iteration_end) { | |||||
| GELOGE(INTERNAL_ERROR, "[Executor: %d] Unexpected iteration: %d", id_, task_info.iteration); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (task_info.event != nullptr) { | |||||
| GELOGD("[%d] Add StreamWaitEvent", id_); | |||||
| GE_CHK_RT_RET(rtStreamWaitEvent(stream_, task_info.event)); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] [Stage = %d] End", task_info.iteration - 1, | |||||
| task_info.stage); | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] [Stage = %d] Start", task_info.iteration, | |||||
| task_info.stage); | |||||
| if (task_info.stage == 0) { | |||||
| GELOGD("[Executor: %d] To ResetExecutionContext", id_); | |||||
| GE_CHK_STATUS_RET(ResetExecutionContext(context_), "[Executor: %d] Failed to reset context", id_); | |||||
| context_.iteration = task_info.iteration; | |||||
| GE_CHK_STATUS_RET_NOLOG(SetInputs(inputs, input_desc)); | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Stage = %d] PartialExecuteAsync Start", task_info.stage); | |||||
| GE_CHK_STATUS_RET(root_graph_executor_->PartialExecuteAsync(task_info.stage)); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Stage = %d] PartialExecuteAsync End", task_info.stage); | |||||
| GELOGD("[Executor: %d] PartialExecuteAsync successfully.", id_); | |||||
| // notify next execution unit | |||||
| StageTask next_task; | |||||
| next_task.stage = task_info.stage; | |||||
| next_task.iteration = task_info.iteration + 1; | |||||
| auto sync_result = Synchronize(); | |||||
| if (sync_result != SUCCESS) { | |||||
| GELOGE(sync_result, "[Executor: %d] Failed to sync result. iteration = %d", id_, task_info.iteration); | |||||
| context_.profiler->Dump(std::cout); | |||||
| context_.callback_manager->Destroy(); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | |||||
| return sync_result; | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] [Stage = %d] End", task_info.iteration, task_info.stage); | |||||
| // if not end stage | |||||
| if (task_info.stage >= pipe_config_->num_stages - 1) { | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] Schedule End", task_info.iteration); | |||||
| GELOGD("[Executor: %d] End of iteration [%ld]", id_, task_info.iteration); | |||||
| context_.callback_manager->Destroy(); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | |||||
| } | |||||
| next_executor_->ExecuteAsync(next_task); | |||||
| GELOGD("[Executor: %d] Push item successfully.", id_); | |||||
| } | |||||
| GELOGD("[Executor: %d] Process task ended.", id_); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::ExecuteAsync(const StageTask &args) { | |||||
| (void)task_queue_.Push(args); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::Synchronize() { | |||||
| auto ret = root_graph_executor_->Synchronize(); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End, ret = %u", ret); | |||||
| return ret; | |||||
| } | |||||
| HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id) | |||||
| : model_(model), device_id_(device_id) { | |||||
| config_.num_executors = kNumExecutors; | |||||
| config_.num_stages = model_->GetRootGraphItem()->NumGroups(); | |||||
| config_.device_id = device_id_; | |||||
| } | |||||
| Status StageExecutor::InitExecutionContext() { | |||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | |||||
| context_.model = model_; | |||||
| context_.session_id = ::ge::GetContext().SessionId(); | |||||
| GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | |||||
| context_.allocator = NpuMemoryAllocator::GetAllocator(pipe_config_->device_id); | |||||
| GE_CHECK_NOTNULL(context_.allocator); | |||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new (std::nothrow) CallbackManager()); | |||||
| GE_CHECK_NOTNULL(context_.callback_manager); | |||||
| context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); | |||||
| if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { | |||||
| context_.trace_enabled = true; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::SetInputs(const vector<TensorValue> &inputs, const vector<ConstGeTensorDescPtr> &input_desc) { | |||||
| root_graph_executor_->InitForPartialExecution(inputs, input_desc); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::GetOutputs(vector<TensorValue> &outputs, vector<ConstGeTensorDescPtr> &output_desc) { | |||||
| return root_graph_executor_->GetOutputs(outputs, output_desc); | |||||
| } | |||||
| void StageExecutor::Reset() { | |||||
| task_queue_.Stop(); | |||||
| task_queue_.Clear(); | |||||
| task_queue_.Restart(); | |||||
| } | |||||
| Status HybridModelPipelineExecutor::Init() { | |||||
| const char *profiling_level = std::getenv(kEnvProfilingLevel); | |||||
| if (profiling_level != nullptr) { | |||||
| context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase); | |||||
| GELOGD("Got profiling level = %ld", context_.profiling_level); | |||||
| if (context_.profiling_level > 0) { | |||||
| context_.profiler.reset(new (std::nothrow) HybridProfiler()); | |||||
| GE_CHECK_NOTNULL(context_.profiler); | |||||
| } | |||||
| } | |||||
| GELOGD("Number of stages = %d, number of executors = %d", config_.num_stages, config_.num_executors); | |||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&config_.rt_context)); | |||||
| GE_CHK_STATUS_RET_NOLOG(InitStageExecutors()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelPipelineExecutor::InitStageExecutors() { | |||||
| for (int i = 0; i < config_.num_executors; ++i) { | |||||
| auto stage_executor = std::unique_ptr<StageExecutor>(new (std::nothrow) StageExecutor(i, model_, &config_)); | |||||
| GE_CHECK_NOTNULL(stage_executor); | |||||
| GE_CHK_STATUS_RET_NOLOG(stage_executor->Init()); | |||||
| if (context_.profiler != nullptr) { | |||||
| // will call unique_ptr::release later | |||||
| stage_executor->context_.profiler.reset(context_.profiler.get()); | |||||
| stage_executor->context_.profiling_level = context_.profiling_level; | |||||
| } | |||||
| stage_executors_.emplace_back(std::move(stage_executor)); | |||||
| } | |||||
| // build propagation loop | |||||
| for (int i = 0; i < config_.num_executors - 1; ++i) { | |||||
| stage_executors_[i]->SetNext(stage_executors_[i + 1].get()); | |||||
| } | |||||
| stage_executors_[config_.num_executors - 1]->SetNext(stage_executors_[0].get()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelPipelineExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| int loop_count = args.num_loops; | |||||
| GE_CHECK_GE(loop_count, 2); | |||||
| auto &inputs = args.inputs; | |||||
| auto &input_desc = args.input_desc; | |||||
| // Start schedulers | |||||
| std::vector<std::future<Status>> futures; | |||||
| for (size_t i = 0; i < stage_executors_.size(); ++i) { | |||||
| GELOGD("Starting executor %zu", i); | |||||
| auto executor = stage_executors_[i].get(); | |||||
| executor->Reset(); | |||||
| auto future = std::async( | |||||
| [loop_count, executor, inputs, input_desc]() { return executor->Start(inputs, input_desc, loop_count); }); | |||||
| futures.emplace_back(std::move(future)); | |||||
| } | |||||
| // Push initial tasks | |||||
| GELOGD("Start to execute with loops, loop count = %d", loop_count); | |||||
| config_.iteration_end = iteration_ + loop_count; | |||||
| for (int i = 0; i < config_.num_stages; ++i) { | |||||
| StageExecutor::StageTask task_info; | |||||
| task_info.stage = i; | |||||
| task_info.iteration = iteration_; | |||||
| stage_executors_[0]->ExecuteAsync(task_info); | |||||
| } | |||||
| // Wait for end of iterations | |||||
| bool has_error = false; | |||||
| for (size_t i = 0; i < stage_executors_.size(); ++i) { | |||||
| GELOGD("Start to sync result of executor[%zu]", i); | |||||
| auto ret = futures[i].get(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Executor: %zu] Failed to schedule tasks.", i); | |||||
| has_error = true; | |||||
| continue; | |||||
| } | |||||
| ret = stage_executors_[i]->Synchronize(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Executor: %zu] Failed to synchronize result.", i); | |||||
| has_error = true; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| // record for profiling analyzer | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | |||||
| if (context_.profiler != nullptr) { | |||||
| context_.profiler->Dump(std::cout); | |||||
| } | |||||
| iteration_ = config_.iteration_end; | |||||
| if (has_error) { | |||||
| GELOGE(FAILED, "Error occurred while execution"); | |||||
| return FAILED; | |||||
| } | |||||
| auto last_iter_executor_idx = loop_count % stage_executors_.size(); | |||||
| GE_CHK_STATUS_RET(stage_executors_[last_iter_executor_idx]->GetOutputs(args.outputs, args.output_desc), | |||||
| "Failed to get output from executor[%d]", last_iter_executor_idx); | |||||
| return SUCCESS; | |||||
| } | |||||
| HybridModelPipelineExecutor::~HybridModelPipelineExecutor() { | |||||
| GELOGD("~HybridModelPipelineExecutor()"); | |||||
| for (auto &executor : stage_executors_) { | |||||
| (void)executor->context_.profiler.release(); | |||||
| } | |||||
| } | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,88 @@ | |||||
| #ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ | |||||
| #define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ | |||||
| #include "common/blocking_queue.h" | |||||
| #include "common/thread_pool.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | |||||
| #include "hybrid/executor/rt_callback_manager.h" | |||||
| #include "hybrid/executor/subgraph_executor.h" | |||||
| #include "hybrid_model_executor.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| struct PipeExecutionConfig { | |||||
| uint32_t device_id; | |||||
| rtContext_t rt_context; | |||||
| int num_executors; | |||||
| int num_stages; | |||||
| long iteration_end; | |||||
| }; | |||||
| class StageExecutor { | |||||
| public: | |||||
| struct StageTask { | |||||
| rtEvent_t event = nullptr; | |||||
| int stage = 0; | |||||
| long iteration = 0; | |||||
| }; | |||||
| StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config); | |||||
| ~StageExecutor(); | |||||
| Status Init(); | |||||
| void Reset(); | |||||
| Status Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc, | |||||
| int loop_count); | |||||
| Status SetInputs(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc); | |||||
| Status ExecuteAsync(const StageTask &args); | |||||
| Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc); | |||||
| Status Synchronize(); | |||||
| void SetNext(StageExecutor *next_executor) { next_executor_ = next_executor; } | |||||
| private: | |||||
| friend class HybridModelPipelineExecutor; | |||||
| static Status ResetExecutionContext(GraphExecutionContext &context); | |||||
| Status InitExecutionContext(); | |||||
| int id_; | |||||
| HybridModel *model_; | |||||
| PipeExecutionConfig *pipe_config_; | |||||
| BlockingQueue<StageTask> task_queue_; | |||||
| std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||||
| GraphExecutionContext context_; | |||||
| StageExecutor *next_executor_; | |||||
| rtStream_t stream_ = nullptr; | |||||
| }; | |||||
| class HybridModelPipelineExecutor { | |||||
| public: | |||||
| HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id); | |||||
| ~HybridModelPipelineExecutor(); | |||||
| Status Init(); | |||||
| Status InitStageExecutors(); | |||||
| Status Execute(HybridModelExecutor::ExecuteArgs &args); | |||||
| private: | |||||
| HybridModel *model_; | |||||
| uint32_t device_id_; | |||||
| std::vector<std::unique_ptr<StageExecutor>> stage_executors_; | |||||
| PipeExecutionConfig config_; | |||||
| GraphExecutionContext context_; | |||||
| long iteration_ = 0; | |||||
| }; | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| #endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ | |||||
| @@ -24,7 +24,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| const int kMaxEvents = 10000; | |||||
| const int kMaxEvents = 1024 * 500; | |||||
| const int kEventDescMax = 512; | const int kEventDescMax = 512; | ||||
| const int kMaxEventTypes = 8; | const int kMaxEventTypes = 8; | ||||
| const int kIndent = 8; | const int kIndent = 8; | ||||
| @@ -46,11 +46,14 @@ void HybridProfiler::RecordEvent(EventType event_type, const char *fmt, ...) { | |||||
| } | } | ||||
| va_end(args); | va_end(args); | ||||
| std::string event = buf; | |||||
| auto index = counter_++; | auto index = counter_++; | ||||
| if (index >= static_cast<int>(events_.size())) { | |||||
| GELOGE(INTERNAL_ERROR, "index out of range. index = %d, max event size = %zu", index, events_.size()); | |||||
| return; | |||||
| } | |||||
| auto &evt = events_[index]; | auto &evt = events_[index]; | ||||
| evt.timestamp = std::chrono::system_clock::now(); | evt.timestamp = std::chrono::system_clock::now(); | ||||
| evt.desc = std::move(event); | |||||
| evt.desc = std::string(buf); | |||||
| evt.event_type = event_type; | evt.event_type = event_type; | ||||
| } | } | ||||
| @@ -78,7 +81,7 @@ void HybridProfiler::Dump(std::ostream &output_stream) { | |||||
| auto cost_dump = std::chrono::duration_cast<std::chrono::microseconds>(end_dump - start_dump).count(); | auto cost_dump = std::chrono::duration_cast<std::chrono::microseconds>(end_dump - start_dump).count(); | ||||
| output_stream << std::setw(kIndent) << elapsed_dump << "\t\t" << cost_dump | output_stream << std::setw(kIndent) << elapsed_dump << "\t\t" << cost_dump | ||||
| << "\t\t" << "[Dump profiling]" << std::endl; | << "\t\t" << "[Dump profiling]" << std::endl; | ||||
| events_.clear(); | |||||
| Reset(); | |||||
| } | } | ||||
| void HybridProfiler::Reset() { | void HybridProfiler::Reset() { | ||||
| @@ -34,6 +34,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
| GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", | GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| this->num_pending_shapes_); | this->num_pending_shapes_); | ||||
| for (int i = 0; i < node_item.num_inputs; ++i){ | |||||
| input_tensor_desc.emplace_back(std::move(*node_item.MutableInputDesc(i))); | |||||
| } | |||||
| for (int i = 0; i < node_item.num_outputs; ++i){ | |||||
| output_tensor_desc.emplace_back(std::move(*node_item.MutableOutputDesc(i))); | |||||
| } | |||||
| } | } | ||||
| Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | ||||
| @@ -56,11 +64,10 @@ Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target | |||||
| tensor_size); | tensor_size); | ||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| auto tensor_desc = node_item.MutableInputDesc(idx); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| tensor_desc->SetShape(target.GetShape()); | |||||
| tensor_desc->SetOriginShape(target.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| auto &input_desc = input_tensor_desc[idx]; | |||||
| input_desc.SetShape(target.GetShape()); | |||||
| input_desc.SetOriginShape(target.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(input_desc, tensor_size); | |||||
| if (--num_pending_shapes_ <= 0) { | if (--num_pending_shapes_ <= 0) { | ||||
| ready_cv_.notify_all(); | ready_cv_.notify_all(); | ||||
| } | } | ||||
| @@ -115,12 +122,27 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| } | } | ||||
| } | } | ||||
| for (size_t i = 0; i < input_tensor_desc.size(); ++i) { | |||||
| auto dst_tensor_desc = node_item.op_desc->MutableInputDesc(i); | |||||
| if (dst_tensor_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto &tensor_desc = input_tensor_desc[i]; | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(tensor_desc, tensor_size); | |||||
| dst_tensor_desc->SetShape(tensor_desc.MutableShape()); | |||||
| dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*dst_tensor_desc, tensor_size); | |||||
| } | |||||
| for (auto &p : shape_futures) { | for (auto &p : shape_futures) { | ||||
| auto idx = p.first; | auto idx = p.first; | ||||
| auto &future = p.second; | auto &future = p.second; | ||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | ||||
| GeTensorDescPtr src_tensor_desc; | |||||
| GE_CHK_STATUS_RET_NOLOG(future.GetTensorDesc(src_tensor_desc)); | |||||
| const GeTensorDesc* src_tensor_desc = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(future.GetTensorDesc(&src_tensor_desc)); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | GE_CHECK_NOTNULL(src_tensor_desc); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | ||||
| @@ -142,10 +164,28 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ShapeFuture::ShapeFuture(NodePtr src_node, | |||||
| const vector<GeTensorDesc> &ShapeInferenceState::GetOutputTensorDesc() const { | |||||
| return output_tensor_desc; | |||||
| } | |||||
| Status ShapeInferenceState::UpdateOutputDesc() { | |||||
| for (size_t i = 0; i < output_tensor_desc.size(); ++i) { | |||||
| auto src_tensor_desc = node_item.MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | |||||
| auto &dst_tensor_desc = output_tensor_desc[i]; | |||||
| dst_tensor_desc.SetShape(src_tensor_desc->MutableShape()); | |||||
| dst_tensor_desc.SetOriginShape(src_tensor_desc->GetOriginShape()); | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | |||||
| (void) TensorUtils::SetSize(dst_tensor_desc, tensor_size); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| ShapeFuture::ShapeFuture(NodeState *src_node, | |||||
| uint32_t src_index, | uint32_t src_index, | ||||
| SubgraphContext *subgraph_context) | SubgraphContext *subgraph_context) | ||||
| : src_node_(std::move(src_node)), src_index_(src_index), subgraph_context_(subgraph_context) { | |||||
| : src_node_(src_node), src_index_(src_index), subgraph_context_(subgraph_context) { | |||||
| } | } | ||||
| NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | ||||
| @@ -187,6 +227,13 @@ Status NodeState::WaitForPrepareDone() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NodeState::UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape) { | |||||
| auto self_tensor_desc = op_desc_->MutableOutputDesc(index); | |||||
| GE_CHECK_NOTNULL(self_tensor_desc); | |||||
| self_tensor_desc->SetShape(shape); | |||||
| self_tensor_desc->SetOriginShape(ori_shape); | |||||
| return SUCCESS; | |||||
| } | |||||
| void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) { | void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) { | ||||
| task_context_ = task_context; | task_context_ = task_context; | ||||
| @@ -198,17 +245,19 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
| Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | ||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | ||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled"); | |||||
| shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); | |||||
| ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); | |||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled"); | |||||
| auto &output_desc = src_node_->GetShapeInferenceState().GetOutputTensorDesc().at(src_index_); | |||||
| shape = output_desc.GetShape(); | |||||
| ori_shape = output_desc.GetOriginShape(); | |||||
| GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeFuture::GetTensorDesc(GeTensorDescPtr &tensor_desc) { | |||||
| Status ShapeFuture::GetTensorDesc(const GeTensorDesc **tensor_desc) { | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | ||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled"); | |||||
| tensor_desc = src_node_->GetOpDesc()->MutableOutputDesc(src_index_); | |||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled"); | |||||
| *tensor_desc = &src_node_->GetShapeInferenceState().GetOutputTensorDesc().at(src_index_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| @@ -30,16 +30,17 @@ class NodeTask; | |||||
| struct GraphExecutionContext; | struct GraphExecutionContext; | ||||
| class SubgraphContext; | class SubgraphContext; | ||||
| class TaskContext; | class TaskContext; | ||||
| class NodeState; | |||||
| class ShapeFuture { | class ShapeFuture { | ||||
| public: | public: | ||||
| ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); | |||||
| ShapeFuture(NodeState *src_node, uint32_t src_index, SubgraphContext *subgraph_context); | |||||
| ~ShapeFuture() = default; | ~ShapeFuture() = default; | ||||
| Status Get(GeShape &ori_shape, GeShape &shape); | Status Get(GeShape &ori_shape, GeShape &shape); | ||||
| Status GetTensorDesc(GeTensorDescPtr &tensor_desc); | |||||
| Status GetTensorDesc(const GeTensorDesc **tensor_desc); | |||||
| private: | private: | ||||
| NodePtr src_node_; | |||||
| NodeState *src_node_; | |||||
| uint32_t src_index_; | uint32_t src_index_; | ||||
| SubgraphContext *subgraph_context_; | SubgraphContext *subgraph_context_; | ||||
| }; | }; | ||||
| @@ -53,10 +54,19 @@ struct ShapeInferenceState { | |||||
| Status AwaitShapesReady(const GraphExecutionContext &context); | Status AwaitShapesReady(const GraphExecutionContext &context); | ||||
| Status UpdateOutputDesc(); | |||||
| const vector<GeTensorDesc> &GetOutputTensorDesc() const; | |||||
| const NodeItem &node_item; | const NodeItem &node_item; | ||||
| private: | private: | ||||
| friend struct NodeState; | |||||
| std::vector<std::pair<int, ShapeFuture>> shape_futures; | std::vector<std::pair<int, ShapeFuture>> shape_futures; | ||||
| // do not directly update op_desc, in case race condition across pipelines | |||||
| std::vector<GeTensorDesc> input_tensor_desc; | |||||
| std::vector<GeTensorDesc> output_tensor_desc; | |||||
| int num_pending_shapes_ = 0; | int num_pending_shapes_ = 0; | ||||
| std::condition_variable ready_cv_; | std::condition_variable ready_cv_; | ||||
| std::mutex mu_; | std::mutex mu_; | ||||
| @@ -88,6 +98,8 @@ struct NodeState { | |||||
| return shape_inference_state_; | return shape_inference_state_; | ||||
| } | } | ||||
| Status UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape); | |||||
| const shared_ptr<NodeTask> &GetKernelTask() const { | const shared_ptr<NodeTask> &GetKernelTask() const { | ||||
| return kernel_task_; | return kernel_task_; | ||||
| } | } | ||||
| @@ -21,14 +21,11 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| CallbackManager::CallbackManager(rtStream_t stream) : stream_(stream) { | |||||
| } | |||||
| Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) { | |||||
| Status CallbackManager::RegisterCallback(rtStream_t stream, rtCallback_t callback, void *user_data) { | |||||
| GELOGD("To register callback"); | GELOGD("To register callback"); | ||||
| rtEvent_t event = nullptr; | rtEvent_t event = nullptr; | ||||
| GE_CHK_RT_RET(rtEventCreate(&event)); | GE_CHK_RT_RET(rtEventCreate(&event)); | ||||
| auto rt_ret = rtEventRecord(event, stream_); | |||||
| auto rt_ret = rtEventRecord(event, stream); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Failed to invoke rtEventRecord, error code = %d", rt_ret); | GELOGE(RT_FAILED, "Failed to invoke rtEventRecord, error code = %d", rt_ret); | ||||
| (void) rtEventDestroy(event); | (void) rtEventDestroy(event); | ||||
| @@ -112,11 +109,11 @@ void CallbackManager::RtCallbackFunc(void *data) { | |||||
| delete callback_func; | delete callback_func; | ||||
| } | } | ||||
| Status CallbackManager::RegisterCallback(const std::function<void()> &callback) { | |||||
| Status CallbackManager::RegisterCallback(rtStream_t stream, const std::function<void()> &callback) { | |||||
| auto func = std::unique_ptr<std::function<void()>>(new(std::nothrow) std::function<void()>(callback)); | auto func = std::unique_ptr<std::function<void()>>(new(std::nothrow) std::function<void()>(callback)); | ||||
| GE_CHECK_NOTNULL(func); | GE_CHECK_NOTNULL(func); | ||||
| GELOGD("Callback registered"); | GELOGD("Callback registered"); | ||||
| return RegisterCallback(RtCallbackFunc, func.release()); | |||||
| return RegisterCallback(stream, RtCallbackFunc, func.release()); | |||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -30,23 +30,21 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| class CallbackManager { | class CallbackManager { | ||||
| public: | public: | ||||
| explicit CallbackManager(rtStream_t stream); | |||||
| CallbackManager() = default; | |||||
| ~CallbackManager() = default; | ~CallbackManager() = default; | ||||
| Status Init(); | Status Init(); | ||||
| Status Destroy(); | Status Destroy(); | ||||
| Status RegisterCallback(rtCallback_t callback, void *user_data); | |||||
| Status RegisterCallback(const std::function<void()> &callback); | |||||
| Status RegisterCallback(rtStream_t stream, rtCallback_t callback, void *user_data); | |||||
| Status RegisterCallback(rtStream_t stream, const std::function<void()> &callback); | |||||
| private: | private: | ||||
| Status CallbackProcess(rtContext_t context); | Status CallbackProcess(rtContext_t context); | ||||
| static void RtCallbackFunc(void *data); | static void RtCallbackFunc(void *data); | ||||
| BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_; | BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_; | ||||
| rtStream_t stream_; | |||||
| std::future<Status> ret_future_; | std::future<Status> ret_future_; | ||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| @@ -24,6 +24,7 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| constexpr int kDefaultThreadNum = 4; | constexpr int kDefaultThreadNum = 4; | ||||
| constexpr int kDefaultQueueSize = 16; | |||||
| constexpr int kDataInputIndex = 0; | constexpr int kDataInputIndex = 0; | ||||
| } | } | ||||
| @@ -31,7 +32,8 @@ SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionCo | |||||
| : graph_item_(graph_item), | : graph_item_(graph_item), | ||||
| context_(context), | context_(context), | ||||
| force_infer_shape_(force_infer_shape), | force_infer_shape_(force_infer_shape), | ||||
| pre_run_pool_(kDefaultThreadNum) { | |||||
| pre_run_pool_(kDefaultThreadNum), | |||||
| ready_queue_(kDefaultQueueSize) { | |||||
| } | } | ||||
| SubgraphExecutor::~SubgraphExecutor() { | SubgraphExecutor::~SubgraphExecutor() { | ||||
| @@ -169,7 +171,7 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->SetKernelTask(node_item->kernel_task); | node_state->SetKernelTask(node_item->kernel_task); | ||||
| known_shape_task_context_ = TaskContext::Create(*node_item, context_, subgraph_context_.get()); | |||||
| known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(known_shape_task_context_); | GE_CHECK_NOTNULL(known_shape_task_context_); | ||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_), | HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_), | ||||
| @@ -201,11 +203,11 @@ Status SubgraphExecutor::ExecuteAsync(TaskContext &task_context) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SubgraphExecutor::PrepareNodes() { | |||||
| GELOGD("[%s] Start to prepare nodes. force infer shape = %s.", | |||||
| Status SubgraphExecutor::PrepareNodes(int group) { | |||||
| GELOGD("[%s] Start to prepare nodes. group = %d", | |||||
| graph_item_->GetName().c_str(), | graph_item_->GetName().c_str(), | ||||
| force_infer_shape_ ? "true" : "false"); | |||||
| auto &all_nodes = graph_item_->GetAllNodes(); | |||||
| group); | |||||
| auto &all_nodes = graph_item_->GetAllNodes(group); | |||||
| for (auto all_node : all_nodes) { | for (auto all_node : all_nodes) { | ||||
| auto &node_item = *all_node; | auto &node_item = *all_node; | ||||
| // for while op | // for while op | ||||
| @@ -240,7 +242,8 @@ Status SubgraphExecutor::PrepareNodes() { | |||||
| } else { | } else { | ||||
| node_state->SetKernelTask(node_item.kernel_task); | node_state->SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| auto unique_task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get()); | |||||
| auto unique_task_context = | |||||
| TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | GE_CHECK_NOTNULL(unique_task_context); | ||||
| const auto &task = node_state->GetKernelTask(); | const auto &task = node_state->GetKernelTask(); | ||||
| if (task == nullptr) { | if (task == nullptr) { | ||||
| @@ -265,15 +268,17 @@ Status SubgraphExecutor::PrepareNodes() { | |||||
| GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | ||||
| } | } | ||||
| GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) { | |||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const { | |||||
| GetContext().SetSessionId(context_->context_id); | |||||
| HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), | HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), | ||||
| "[%s] Failed to InferShape.", node_state.GetName().c_str()); | |||||
| HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_item), | |||||
| "[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str()); | |||||
| "[%s] Failed to InferShape.", node_state.GetName().c_str()); | |||||
| GetContext().SetSessionId(context_->session_id); | |||||
| HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_state), | |||||
| "[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -285,7 +290,7 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||||
| } else { | } else { | ||||
| node_state.SetKernelTask(node_item.kernel_task); | node_state.SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| auto unique_task_context = TaskContext::Create(*node_state.GetNodeItem(), context_, subgraph_context_.get()); | |||||
| auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | GE_CHECK_NOTNULL(unique_task_context); | ||||
| const auto &task = node_state.GetKernelTask(); | const auto &task = node_state.GetKernelTask(); | ||||
| if (task == nullptr) { | if (task == nullptr) { | ||||
| @@ -336,11 +341,11 @@ Status SubgraphExecutor::LaunchTasks() { | |||||
| } | } | ||||
| } | } | ||||
| Status SubgraphExecutor::ScheduleTasks() { | |||||
| Status SubgraphExecutor::ScheduleTasks(int group) { | |||||
| GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | ||||
| auto prepare_future = std::async(std::launch::async, [&]() -> Status { | auto prepare_future = std::async(std::launch::async, [&]() -> Status { | ||||
| GetContext().SetSessionId(context_->session_id); | GetContext().SetSessionId(context_->session_id); | ||||
| auto ret = PrepareNodes(); | |||||
| auto ret = PrepareNodes(group); | |||||
| ready_queue_.Push(nullptr); | ready_queue_.Push(nullptr); | ||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| @@ -481,5 +486,14 @@ Status SubgraphExecutor::EnableOutputZeroCopy(const vector<TensorValue> &outputs | |||||
| GELOGD("Done enabling zero copy for outputs successfully."); | GELOGD("Done enabling zero copy for outputs successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SubgraphExecutor::PartialExecuteAsync(int task_group) { | |||||
| return ScheduleTasks(task_group); | |||||
| } | |||||
| Status SubgraphExecutor::InitForPartialExecution(const vector<TensorValue> &inputs, | |||||
| const vector<ConstGeTensorDescPtr> &input_desc) { | |||||
| return Init(inputs, input_desc); | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -36,6 +36,11 @@ class SubgraphExecutor { | |||||
| SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); | SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); | ||||
| ~SubgraphExecutor(); | ~SubgraphExecutor(); | ||||
| Status InitForPartialExecution(const std::vector<TensorValue> &inputs, | |||||
| const std::vector<ConstGeTensorDescPtr> &input_desc); | |||||
| Status PartialExecuteAsync(int task_group); | |||||
| /** | /** | ||||
| * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | ||||
| * valid after this method returned | * valid after this method returned | ||||
| @@ -89,15 +94,15 @@ class SubgraphExecutor { | |||||
| private: | private: | ||||
| Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); | Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); | ||||
| Status EnableOutputZeroCopy(const std::vector<TensorValue> &outputs); | Status EnableOutputZeroCopy(const std::vector<TensorValue> &outputs); | ||||
| static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state); | |||||
| Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const; | |||||
| Status Init(const std::vector<TensorValue> &inputs, | Status Init(const std::vector<TensorValue> &inputs, | ||||
| const std::vector<ConstGeTensorDescPtr> &input_desc); | const std::vector<ConstGeTensorDescPtr> &input_desc); | ||||
| Status InitInputsForUnknownShape(const std::vector<TensorValue> &inputs, | Status InitInputsForUnknownShape(const std::vector<TensorValue> &inputs, | ||||
| const std::vector<ConstGeTensorDescPtr> &input_desc); | const std::vector<ConstGeTensorDescPtr> &input_desc); | ||||
| Status InitInputsForKnownShape(const std::vector<TensorValue> &inputs); | Status InitInputsForKnownShape(const std::vector<TensorValue> &inputs); | ||||
| Status ExecuteAsyncForKnownShape(const std::vector<TensorValue> &inputs); | Status ExecuteAsyncForKnownShape(const std::vector<TensorValue> &inputs); | ||||
| Status ScheduleTasks(); | |||||
| Status PrepareNodes(); | |||||
| Status ScheduleTasks(int group = -1); | |||||
| Status PrepareNodes(int group = -1); | |||||
| Status LaunchTasks(); | Status LaunchTasks(); | ||||
| Status SetOutputsToParentNode(TaskContext &task_context); | Status SetOutputsToParentNode(TaskContext &task_context); | ||||
| @@ -125,16 +125,16 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { | |||||
| RT_MEMCPY_DEVICE_TO_HOST)); | RT_MEMCPY_DEVICE_TO_HOST)); | ||||
| } | } | ||||
| tensor.SetData(std::move(host_buffer)); | tensor.SetData(std::move(host_buffer)); | ||||
| string session_id = std::to_string(context_->GetSessionId()); | |||||
| string context_id = std::to_string(graph_context_->context_id); | |||||
| RuntimeInferenceContext *runtime_infer_ctx = nullptr; | RuntimeInferenceContext *runtime_infer_ctx = nullptr; | ||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx), | |||||
| "Failed to get RuntimeInferenceContext, session_id = %s", session_id.c_str()); | |||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(context_id, &runtime_infer_ctx), | |||||
| "Failed to get RuntimeInferenceContext, context_id = %s", context_id.c_str()); | |||||
| GE_CHK_STATUS_RET(runtime_infer_ctx->SetTensor(node_item.node_id, output_idx, std::move(tensor)), | GE_CHK_STATUS_RET(runtime_infer_ctx->SetTensor(node_item.node_id, output_idx, std::move(tensor)), | ||||
| "Failed to SetTensor, node = %s, output_index = %d", node_item.NodeName().c_str(), output_idx); | "Failed to SetTensor, node = %s, output_index = %d", node_item.NodeName().c_str(), output_idx); | ||||
| GELOGD("[%s] Output[%d] cached successfully in session: %s. node_id = %d, shape = [%s]", | |||||
| GELOGD("[%s] Output[%d] cached successfully in context: %s. node_id = %d, shape = [%s]", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| output_idx, | output_idx, | ||||
| session_id.c_str(), | |||||
| context_id.c_str(), | |||||
| node_item.node_id, | node_item.node_id, | ||||
| ge_tensor_desc->GetShape().ToString().c_str()); | ge_tensor_desc->GetShape().ToString().c_str()); | ||||
| @@ -332,6 +332,7 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| // update output tensor sizes | // update output tensor sizes | ||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(context_->GetNodeState()->GetShapeInferenceState().UpdateOutputDesc()); | |||||
| } | } | ||||
| // PropagateOutputs for type == DEPEND_COMPUTE | // PropagateOutputs for type == DEPEND_COMPUTE | ||||
| if (node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| @@ -363,7 +364,7 @@ Status ExecutionEngine::ExecuteAsync(NodeState &node_state, | |||||
| RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); | RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); | ||||
| auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context)); | auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context)); | ||||
| GE_CHECK_NOTNULL(cb); | GE_CHECK_NOTNULL(cb); | ||||
| auto callback = [&, cb]() { | |||||
| auto callback = [task_context, cb]() { | |||||
| auto ret = cb->OnNodeDone(); | auto ret = cb->OnNodeDone(); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| task_context->OnError(ret); | task_context->OnError(ret); | ||||
| @@ -109,7 +109,8 @@ Status ShapeInferenceEngine::AwaitDependentNodes(NodeState &node_state) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| Status ShapeInferenceEngine::PropagateOutputShapes(NodeState &node_state) { | |||||
| auto &node_item = *node_state.GetNodeItem(); | |||||
| if (node_item.is_output_shape_static) { | if (node_item.is_output_shape_static) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -140,9 +141,8 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| // in case type 3 and 4, shape will be valid after computing is done | // in case type 3 and 4, shape will be valid after computing is done | ||||
| auto &infer_state = dst_node_state->GetShapeInferenceState(); | auto &infer_state = dst_node_state->GetShapeInferenceState(); | ||||
| if (shape_is_future) { | if (shape_is_future) { | ||||
| ShapeFuture future(node_item.node, i, subgraph_context_); | |||||
| infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, | |||||
| std::move(future)); | |||||
| ShapeFuture future(&node_state, i, subgraph_context_); | |||||
| infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, std::move(future)); | |||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc)); | GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc)); | ||||
| } | } | ||||
| @@ -32,7 +32,7 @@ class ShapeInferenceEngine { | |||||
| Status InferShapeForSubgraph(const NodeItem &node_item, const FusedSubgraph &fused_subgraph); | Status InferShapeForSubgraph(const NodeItem &node_item, const FusedSubgraph &fused_subgraph); | ||||
| Status PropagateOutputShapes(const NodeItem &node_item); | |||||
| Status PropagateOutputShapes(NodeState &node_state); | |||||
| static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | ||||
| @@ -30,6 +30,19 @@ const vector<NodeItem *> &hybrid::GraphItem::GetAllNodes() const { | |||||
| return node_items_; | return node_items_; | ||||
| } | } | ||||
| const vector<NodeItem *> &GraphItem::GetAllNodes(int group) const { | |||||
| if (group == -1) { | |||||
| return GetAllNodes(); | |||||
| } | |||||
| if (group >= static_cast<int>(grouped_node_items_.size())) { | |||||
| static vector<NodeItem *> empty_nodes; | |||||
| return empty_nodes; | |||||
| } | |||||
| return grouped_node_items_[group]; | |||||
| } | |||||
| const vector<const NodeItem *> &GraphItem::GetInputNodes() const { | const vector<const NodeItem *> &GraphItem::GetInputNodes() const { | ||||
| return input_nodes_; | return input_nodes_; | ||||
| } | } | ||||
| @@ -74,5 +87,28 @@ const NodeItem *GraphItem::GetOutputNode() const { | |||||
| const vector<std::pair<const NodeItem *, int>> &GraphItem::GetOutputEdges() const { | const vector<std::pair<const NodeItem *, int>> &GraphItem::GetOutputEdges() const { | ||||
| return output_edges_; | return output_edges_; | ||||
| } | } | ||||
| Status GraphItem::GroupNodes() { | |||||
| int last_group = INT32_MIN; | |||||
| std::set<int> seen_groups; | |||||
| for (auto node : node_items_) { | |||||
| int group = node->group; | |||||
| if (group != last_group) { | |||||
| if (seen_groups.find(group) != seen_groups.end()) { | |||||
| GELOGE(INTERNAL_ERROR, "Unordered node group found. node = %s, group = %d", node->NodeName().c_str(), group); | |||||
| return INTERNAL_ERROR; | |||||
| } else { | |||||
| last_group = group; | |||||
| seen_groups.insert(group); | |||||
| grouped_node_items_.emplace_back(std::vector<NodeItem *>()); | |||||
| } | |||||
| } | |||||
| GELOGD("Adding node [%s] to group %d", node->NodeName().c_str(), group); | |||||
| grouped_node_items_.back().emplace_back(node); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -26,7 +26,9 @@ class GraphItem { | |||||
| public: | public: | ||||
| GraphItem() = default; | GraphItem() = default; | ||||
| ~GraphItem(); | ~GraphItem(); | ||||
| Status GroupNodes(); | |||||
| const vector<NodeItem *> &GetAllNodes() const; | const vector<NodeItem *> &GetAllNodes() const; | ||||
| const vector<NodeItem *> &GetAllNodes(int group) const; | |||||
| const vector<const NodeItem *> &GetInputNodes() const; | const vector<const NodeItem *> &GetInputNodes() const; | ||||
| Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; | Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; | ||||
| const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const; | const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const; | ||||
| @@ -46,6 +48,10 @@ class GraphItem { | |||||
| name_ = name; | name_ = name; | ||||
| } | } | ||||
| size_t NumGroups() const { | |||||
| return grouped_node_items_.size(); | |||||
| } | |||||
| const NodeItem *GetOutputNode() const; | const NodeItem *GetOutputNode() const; | ||||
| bool IsDynamic() const; | bool IsDynamic() const; | ||||
| @@ -56,6 +62,7 @@ class GraphItem { | |||||
| friend class HybridModelBuilder; | friend class HybridModelBuilder; | ||||
| std::string name_; | std::string name_; | ||||
| std::vector<NodeItem *> node_items_; | std::vector<NodeItem *> node_items_; | ||||
| std::vector<std::vector<NodeItem *>> grouped_node_items_; | |||||
| std::vector<const NodeItem *> input_nodes_; | std::vector<const NodeItem *> input_nodes_; | ||||
| const NodeItem *output_node_ = nullptr; | const NodeItem *output_node_ = nullptr; | ||||
| // <src_node, out_index> | // <src_node, out_index> | ||||
| @@ -52,7 +52,7 @@ Status HybridModel::Init(bool is_single_op) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| TensorValue* HybridModel::GetVariable(const string &name) const { | |||||
| TensorValue *HybridModel::GetVariable(const string &name) const { | |||||
| auto it = variable_tensors_.find(name); | auto it = variable_tensors_.find(name); | ||||
| if (it == variable_tensors_.end()) { | if (it == variable_tensors_.end()) { | ||||
| GELOGD("Failed to get variable tensor. var name = [%s]", name.c_str()); | GELOGD("Failed to get variable tensor. var name = [%s]", name.c_str()); | ||||
| @@ -113,7 +113,7 @@ GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { | |||||
| return it->second; | return it->second; | ||||
| } | } | ||||
| const GraphItem* HybridModel::GetRootGraphItem() const { | |||||
| const GraphItem *HybridModel::GetRootGraphItem() const { | |||||
| return root_graph_item_.get(); | return root_graph_item_.get(); | ||||
| } | } | ||||
| @@ -287,6 +287,16 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| src_node_item->NodeName().c_str()); | src_node_item->NodeName().c_str()); | ||||
| src_node_item->has_observer = true; | src_node_item->has_observer = true; | ||||
| node_item.dependents_for_execution.emplace_back(src_node); | node_item.dependents_for_execution.emplace_back(src_node); | ||||
| node_item.has_observer = true; | |||||
| for (auto &dst_node : ge_node->GetOutNodes()) { | |||||
| if (dst_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item)); | |||||
| dst_node_item->dependents_for_execution.emplace_back(ge_node); | |||||
| } | |||||
| } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { | } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { | ||||
| GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", | GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| @@ -614,6 +624,15 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (op_desc->HasAttr(ATTR_STAGE_LEVEL)) { | |||||
| uint32_t stage_level = UINT32_MAX; | |||||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level)) { | |||||
| for (const auto &stage_node : subgraph->GetAllNodes()) { | |||||
| GELOGD("Set ATTR_STAGE_LEVEL on node %s, stage_level=%u", stage_node->GetName().c_str(), stage_level); | |||||
| (void)AttrUtils::SetInt(stage_node->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level); | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | ||||
| "[%s] Failed to merge subgraph.", | "[%s] Failed to merge subgraph.", | ||||
| subgraph->GetName().c_str()); | subgraph->GetName().c_str()); | ||||
| @@ -621,6 +640,14 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||||
| // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | ||||
| GE_CHK_GRAPH_STATUS_RET(merged_graph->TopologicalSorting(), "Failed to invoke TopologicalSorting on merged graph."); | GE_CHK_GRAPH_STATUS_RET(merged_graph->TopologicalSorting(), "Failed to invoke TopologicalSorting on merged graph."); | ||||
| GE_DUMP(merged_graph, "hybrid_merged_graph_BeforeStageSort"); | |||||
| merged_graph->TopologicalSorting([](const NodePtr &a, const NodePtr &b) -> bool { | |||||
| uint32_t a_level = UINT32_MAX; | |||||
| (void)AttrUtils::GetInt(a->GetOpDesc(), ATTR_STAGE_LEVEL, a_level); | |||||
| uint32_t b_level = UINT32_MAX; | |||||
| (void)AttrUtils::GetInt(b->GetOpDesc(), ATTR_STAGE_LEVEL, b_level); | |||||
| return a_level < b_level; | |||||
| }); | |||||
| for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | ||||
| GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | ||||
| @@ -675,41 +702,17 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||||
| } | } | ||||
| Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | ||||
| const NodeItem &node_item, | |||||
| bool is_root_graph) { | |||||
| auto output_size = node_item.num_inputs; | |||||
| graph_item.output_edges_.resize(output_size); | |||||
| for (auto &in_data_anchor : node_item.node->GetAllInDataAnchors()) { | |||||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| auto src_node = peer_out_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto src_node_item = GetNodeItem(src_node); | |||||
| GE_CHECK_NOTNULL(src_node_item); | |||||
| auto output_idx = in_data_anchor->GetIdx(); | |||||
| auto output_offset = src_node_item->output_start + peer_out_anchor->GetIdx(); | |||||
| GELOGI("Output[%d], node = %s, output_index = %d, output_offset = %d ", | |||||
| output_idx, | |||||
| src_node_item->NodeName().c_str(), | |||||
| peer_out_anchor->GetIdx(), | |||||
| output_offset); | |||||
| GE_CHECK_LE(output_idx, output_size - 1); | |||||
| graph_item.output_edges_[output_idx] = {src_node_item, peer_out_anchor->GetIdx()}; | |||||
| } | |||||
| if (!is_root_graph) { | |||||
| for (uint32_t i = 0; i < static_cast<uint32_t>(output_size); ++i) { | |||||
| uint32_t p_index = i; | |||||
| // Net output of Subgraph of while do not have parent index | |||||
| if (AttrUtils::GetInt(node_item.op_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { | |||||
| GELOGD("[%s] Parent index not set for input[%u].", node_item.NodeName().c_str(), i); | |||||
| } | |||||
| graph_item.output_index_mapping_.emplace_back(p_index); | |||||
| const NodeItem &node_item) { | |||||
| auto output_size = node_item.op_desc->GetAllInputsSize(); | |||||
| GE_CHECK_LE(output_size, UINT32_MAX); | |||||
| for (uint32_t i = 0; i < static_cast<uint32_t>(output_size); ++i) { | |||||
| uint32_t p_index = i; | |||||
| // Net output of Subgraph of while do not have parent index | |||||
| if (AttrUtils::GetInt(node_item.op_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { | |||||
| GELOGD("[%s] Parent index not set for input[%u].", node_item.NodeName().c_str(), i); | |||||
| } | } | ||||
| graph_item.output_index_mapping_.emplace_back(p_index); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -732,6 +735,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); | GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); | ||||
| GELOGD("Done loading root graph successfully."); | GELOGD("Done loading root graph successfully."); | ||||
| GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph"); | |||||
| for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | ||||
| GE_CHECK_NOTNULL(sub_graph); | GE_CHECK_NOTNULL(sub_graph); | ||||
| @@ -805,6 +809,7 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_ | |||||
| // var size is only for checking, will not allocate any memory by it | // var size is only for checking, will not allocate any memory by it | ||||
| tensor.reset(new(std::nothrow)TensorValue(dev_mem, static_cast<size_t>(var_size))); | tensor.reset(new(std::nothrow)TensorValue(dev_mem, static_cast<size_t>(var_size))); | ||||
| GE_CHECK_NOTNULL(tensor); | GE_CHECK_NOTNULL(tensor); | ||||
| GELOGI("Get var memory addr %p for node %s, size = %lld, mem_type=%u", dev_mem, var_name.c_str(), var_size, mem_type); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1737,8 +1742,14 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons | |||||
| for (const auto &task_def : task_def_lists) { | for (const auto &task_def : task_def_lists) { | ||||
| hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | ||||
| } | } | ||||
| if (op_desc->HasAttr(ATTR_STAGE_LEVEL)) { | |||||
| uint32_t stage_level = UINT32_MAX; | |||||
| (void)ge::AttrUtils::GetInt(op_desc, ATTR_STAGE_LEVEL, stage_level); | |||||
| (void)ge::AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level); | |||||
| } | |||||
| NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | ||||
| GE_CHECK_NOTNULL(node_item); | |||||
| node_item->input_start = 0; | node_item->input_start = 0; | ||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| graph_item.node_items_.emplace_back(node_item); | graph_item.node_items_.emplace_back(node_item); | ||||
| @@ -1812,8 +1823,14 @@ Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const | |||||
| for (const auto &task_def : task_def_lists) { | for (const auto &task_def : task_def_lists) { | ||||
| hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | ||||
| } | } | ||||
| if (op_desc->HasAttr(ATTR_STAGE_LEVEL)) { | |||||
| uint32_t stage_level = UINT32_MAX; | |||||
| (void)ge::AttrUtils::GetInt(op_desc, ATTR_STAGE_LEVEL, stage_level); | |||||
| (void)ge::AttrUtils::SetInt(profiling_node->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level); | |||||
| } | |||||
| NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | ||||
| GE_CHECK_NOTNULL(node_item); | |||||
| node_item->input_start = 0; | node_item->input_start = 0; | ||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| graph_item.node_items_.emplace_back(node_item); | graph_item.node_items_.emplace_back(node_item); | ||||
| @@ -1859,7 +1876,9 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
| data_nodes.emplace_back(node_item); | data_nodes.emplace_back(node_item); | ||||
| } else if (op_type == NETOUTPUT) { | } else if (op_type == NETOUTPUT) { | ||||
| graph_item->output_node_ = node_item; | graph_item->output_node_ = node_item; | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph)); | |||||
| if (!is_root_graph) { | |||||
| GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item)); | |||||
| } | |||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node)); | GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node)); | ||||
| graph_item->node_items_.emplace_back(node_item); | graph_item->node_items_.emplace_back(node_item); | ||||
| @@ -53,7 +53,7 @@ class HybridModelBuilder { | |||||
| std::vector<NodeItem *> &data_nodes, | std::vector<NodeItem *> &data_nodes, | ||||
| bool is_root_graph); | bool is_root_graph); | ||||
| static Status ResolveRefIo(NodeItem &node_item); | static Status ResolveRefIo(NodeItem &node_item); | ||||
| Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); | |||||
| Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item); | |||||
| Status ValidateParams(); | Status ValidateParams(); | ||||
| Status LoadGraph(); | Status LoadGraph(); | ||||
| Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
| @@ -21,8 +21,8 @@ | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -146,6 +146,20 @@ Status NodeItem::InitInputsAndOutputs() { | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | ||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | num_inputs = static_cast<int>(op_desc->GetInputsSize()); | ||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | ||||
| if (AttrUtils::GetInt(op_desc, ::ge::ATTR_STAGE_LEVEL, group)) { | |||||
| GELOGD("[%s] Got stage level from op_desc = %d", op_desc->GetName().c_str(), group); | |||||
| } else { | |||||
| if (AttrUtils::GetInt(node->GetOwnerComputeGraph(), ::ge::ATTR_STAGE_LEVEL, group)) { | |||||
| GELOGD("[%s] Got stage level from parent graph = %d", op_desc->GetName().c_str(), group); | |||||
| } else { | |||||
| auto parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||||
| if ((parent_node != nullptr) && (AttrUtils::GetInt(parent_node->GetOpDesc(), ::ge::ATTR_STAGE_LEVEL, group))) { | |||||
| GELOGD("[%s] Got stage level from parent node = %d", op_desc->GetName().c_str(), group); | |||||
| } else { | |||||
| GELOGD("[%s] Node do not set stage level", op_desc->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| ResolveOptionalInputs(); | ResolveOptionalInputs(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -244,6 +258,7 @@ std::string NodeItem::DebugString() const { | |||||
| ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); | ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); | ||||
| ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False"); | ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False"); | ||||
| ss << ", unknown_shape_op_type = " << shape_inference_type; | ss << ", unknown_shape_op_type = " << shape_inference_type; | ||||
| ss << ", stage = " << group; | |||||
| ss << ", input_start = " << input_start; | ss << ", input_start = " << input_start; | ||||
| ss << ", num_inputs = " << num_inputs; | ss << ", num_inputs = " << num_inputs; | ||||
| ss << ", output_start = " << output_start; | ss << ", output_start = " << output_start; | ||||
| @@ -74,6 +74,7 @@ struct NodeItem { | |||||
| NodePtr node; | NodePtr node; | ||||
| OpDesc *op_desc; | OpDesc *op_desc; | ||||
| int node_id = -1; | int node_id = -1; | ||||
| int group = -1; | |||||
| int num_inputs = 0; | int num_inputs = 0; | ||||
| int num_outputs = 0; | int num_outputs = 0; | ||||
| int input_start = -1; | int input_start = -1; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "hybrid/node_executor/aicore/aicore_op_task.h" | #include "hybrid/node_executor/aicore/aicore_op_task.h" | ||||
| #include "framework/common/taskdown_common.h" | #include "framework/common/taskdown_common.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/ge_context.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
| #include "graph/load/model_manager/tbe_handle_store.h" | #include "graph/load/model_manager/tbe_handle_store.h" | ||||
| @@ -198,9 +199,12 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { | |||||
| tiling_info.clear_atomic = true; | tiling_info.clear_atomic = true; | ||||
| auto execution_context = context.GetExecutionContext(); | auto execution_context = context.GetExecutionContext(); | ||||
| GetContext().SetSessionId(execution_context->context_id); | |||||
| RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); | RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); | ||||
| GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); | GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); | ||||
| RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); | RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); | ||||
| GetContext().SetSessionId(execution_context->session_id); | |||||
| // update op args by tiling info | // update op args by tiling info | ||||
| block_dim_ = static_cast<uint32_t>(tiling_info.block_dim); | block_dim_ = static_cast<uint32_t>(tiling_info.block_dim); | ||||
| @@ -74,7 +74,7 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AicpuNodeTaskBase::UpdateOutputShapeFromExtInfo() { | |||||
| Status AicpuNodeTaskBase::UpdateOutputShapeFromExtInfo(TaskContext &task_context) { | |||||
| if (node_item_->num_outputs == 0) { | if (node_item_->num_outputs == 0) { | ||||
| GELOGD("Task [%s] output_num is 0, no need update output shape.", node_name_.c_str()); | GELOGD("Task [%s] output_num is 0, no need update output shape.", node_name_.c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -91,19 +91,19 @@ Status AicpuNodeTaskBase::UpdateOutputShapeFromExtInfo() { | |||||
| // not support update data type now, just for param | // not support update data type now, just for param | ||||
| DataType data_type; | DataType data_type; | ||||
| aicpu_ext_handle_.GetOutputShapeAndType(i, shape, data_type); | aicpu_ext_handle_.GetOutputShapeAndType(i, shape, data_type); | ||||
| auto output_desc = node_item_->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(shape, i, output_desc), | |||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(task_context, shape, i), | |||||
| "Update node %s [%d]th output shape failed.", | "Update node %s [%d]th output shape failed.", | ||||
| node_name_.c_str(), i); | node_name_.c_str(), i); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(const GeShape &shape_new, | |||||
| int32_t output_index, GeTensorDescPtr &output_desc) { | |||||
| Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(TaskContext &task_context, | |||||
| const GeShape &shape_new, | |||||
| int32_t output_index) { | |||||
| auto output_desc = task_context.MutableOutputDesc(output_index); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| auto shape_old = output_desc->GetShape(); | auto shape_old = output_desc->GetShape(); | ||||
| output_desc->SetShape(shape_new); | |||||
| GELOGD("Update node[%s] out[%d] shape from %s to %s.", node_name_.c_str(), output_index, | GELOGD("Update node[%s] out[%d] shape from %s to %s.", node_name_.c_str(), output_index, | ||||
| shape_old.ToString().c_str(), shape_new.ToString().c_str()); | shape_old.ToString().c_str(), shape_new.ToString().c_str()); | ||||
| @@ -111,9 +111,9 @@ Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(const GeShape &shape_new, | |||||
| auto origin_format = output_desc->GetOriginFormat(); | auto origin_format = output_desc->GetOriginFormat(); | ||||
| auto format = output_desc->GetFormat(); | auto format = output_desc->GetFormat(); | ||||
| if (origin_format == format) { | if (origin_format == format) { | ||||
| output_desc->SetOriginShape(shape_new); | |||||
| return SUCCESS; | |||||
| return task_context.GetNodeState()->UpdateOutputShapes(output_index, shape_new, shape_new); | |||||
| } | } | ||||
| // if format is not same need convert shape | // if format is not same need convert shape | ||||
| std::vector<int64_t> origin_dims_new; | std::vector<int64_t> origin_dims_new; | ||||
| auto trans_ret = formats::TransShape(format, shape_new.GetDims(), | auto trans_ret = formats::TransShape(format, shape_new.GetDims(), | ||||
| @@ -122,7 +122,8 @@ Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(const GeShape &shape_new, | |||||
| "Node[%s] out[%d] originFormat[%d] is not same as format[%d], but TransShape failed, shape=%s.", | "Node[%s] out[%d] originFormat[%d] is not same as format[%d], but TransShape failed, shape=%s.", | ||||
| node_name_.c_str(), output_index, origin_format, format, shape_new.ToString().c_str()); | node_name_.c_str(), output_index, origin_format, format, shape_new.ToString().c_str()); | ||||
| auto origin_shape_new = GeShape(origin_dims_new); | auto origin_shape_new = GeShape(origin_dims_new); | ||||
| output_desc->SetOriginShape(origin_shape_new); | |||||
| GE_CHK_STATUS_RET(task_context.GetNodeState()->UpdateOutputShapes(output_index, shape_new, origin_shape_new), | |||||
| "Node[%s] failed to update update shape, index = %d", node_name_.c_str(), output_index); | |||||
| GELOGD("Node[%s] out[%d] originFormat[%d] is not same as format[%d], need update from %s ro %s.", | GELOGD("Node[%s] out[%d] originFormat[%d] is not same as format[%d], need update from %s ro %s.", | ||||
| node_name_.c_str(), output_index, origin_format, format, | node_name_.c_str(), output_index, origin_format, format, | ||||
| origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str()); | origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str()); | ||||
| @@ -513,7 +514,6 @@ Status AicpuTfNodeTask::UpdateShapeByHbmBuffer(TaskContext &context, | |||||
| node_name_.c_str(), node_item_->num_outputs, out_shape_hbm.size()); | node_name_.c_str(), node_item_->num_outputs, out_shape_hbm.size()); | ||||
| for (auto i = 0; i < node_item_->num_outputs; ++i) { | for (auto i = 0; i < node_item_->num_outputs; ++i) { | ||||
| const auto &result_summary = output_summary_host_[i]; | const auto &result_summary = output_summary_host_[i]; | ||||
| auto output_desc = node_item_->MutableOutputDesc(i); | |||||
| std::vector<int64_t> shape_dims; | std::vector<int64_t> shape_dims; | ||||
| if (result_summary.shape_data_size > 0) { | if (result_summary.shape_data_size > 0) { | ||||
| const auto &shape_hbm = out_shape_hbm[i]; | const auto &shape_hbm = out_shape_hbm[i]; | ||||
| @@ -531,7 +531,7 @@ Status AicpuTfNodeTask::UpdateShapeByHbmBuffer(TaskContext &context, | |||||
| GELOGD("Node[%s] [%d]th output dim[%u]=%ld.", node_name_.c_str(), i, dim_idx, shape_addr[dim_idx]); | GELOGD("Node[%s] [%d]th output dim[%u]=%ld.", node_name_.c_str(), i, dim_idx, shape_addr[dim_idx]); | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(GeShape(shape_dims), i, output_desc), | |||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(context, GeShape(shape_dims), i), | |||||
| "Node[%s] update [%d]th output shape failed.", | "Node[%s] update [%d]th output shape failed.", | ||||
| node_name_.c_str(), i); | node_name_.c_str(), i); | ||||
| } | } | ||||
| @@ -634,7 +634,7 @@ Status AicpuTfNodeTask::TaskCallback(TaskContext &context) { | |||||
| // check need update shape, call update shape. | // check need update shape, call update shape. | ||||
| if (unknown_type_ == DEPEND_SHAPE_RANGE) { | if (unknown_type_ == DEPEND_SHAPE_RANGE) { | ||||
| // check result | // check result | ||||
| callback_ret = UpdateOutputShapeFromExtInfo(); | |||||
| callback_ret = UpdateOutputShapeFromExtInfo(context); | |||||
| } else if (unknown_type_ == DEPEND_COMPUTE) { | } else if (unknown_type_ == DEPEND_COMPUTE) { | ||||
| callback_ret = UpdateShapeAndDataByResultSummary(context); | callback_ret = UpdateShapeAndDataByResultSummary(context); | ||||
| } | } | ||||
| @@ -781,7 +781,7 @@ Status AicpuNodeTask::TaskCallback(TaskContext &context) { | |||||
| // check need update shape, call update shape. | // check need update shape, call update shape. | ||||
| if (node_item_->is_dynamic && unknown_type_ == DEPEND_SHAPE_RANGE) { | if (node_item_->is_dynamic && unknown_type_ == DEPEND_SHAPE_RANGE) { | ||||
| // check result | // check result | ||||
| callback_ret = UpdateOutputShapeFromExtInfo(); | |||||
| callback_ret = UpdateOutputShapeFromExtInfo(context); | |||||
| } else { | } else { | ||||
| GELOGD("Node[%s] unknown shape type is %d no need update output shape.", | GELOGD("Node[%s] unknown shape type is %d no need update output shape.", | ||||
| node_name_.c_str(), unknown_type_); | node_name_.c_str(), unknown_type_); | ||||
| @@ -49,9 +49,9 @@ class AicpuNodeTaskBase : public NodeTask { | |||||
| virtual Status UpdateExtInfo(); | virtual Status UpdateExtInfo(); | ||||
| virtual Status UpdateOutputShapeFromExtInfo(); | |||||
| virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context); | |||||
| Status UpdateShapeToOutputDesc(const GeShape &shape_new, int32_t output_index, GeTensorDescPtr &output_desc); | |||||
| Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index); | |||||
| virtual Status LaunchTask(TaskContext &context) = 0; | virtual Status LaunchTask(TaskContext &context) = 0; | ||||
| @@ -22,6 +22,8 @@ | |||||
| #include "graph/manager/util/hcom_util.h" | #include "graph/manager/util/hcom_util.h" | ||||
| #include "graph/runtime_inference_context.h" | #include "graph/runtime_inference_context.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/types.h" | |||||
| #include "hccl/hcom.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -96,13 +98,13 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed"); | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed"); | ||||
| } | } | ||||
| op_info.root = root_id; | op_info.root = root_id; | ||||
| auto callback = [this, op_desc](HcclResult status) { | |||||
| auto callback = [op_desc, done_callback](HcclResult status) { | |||||
| if (status != HCCL_SUCCESS) { | if (status != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", | GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", | ||||
| op_desc->GetName().c_str(), status); | op_desc->GetName().c_str(), status); | ||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(this->hccl_mutex_); | |||||
| this->cond_.notify_all(); | |||||
| done_callback(); | |||||
| GELOGI("node %s hccl callback success.", op_desc->GetName().c_str()); | GELOGI("node %s hccl callback success.", op_desc->GetName().c_str()); | ||||
| }; | }; | ||||
| int32_t count = 0; | int32_t count = 0; | ||||
| @@ -119,11 +121,6 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| return HCCL_E_INTERNAL; | return HCCL_E_INTERNAL; | ||||
| } | } | ||||
| // pending until hccl finished | |||||
| std::unique_lock<std::mutex> ulock(hccl_mutex_); | |||||
| cond_.wait(ulock); | |||||
| GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); | |||||
| GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName()); | GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -165,7 +162,8 @@ Status RdmaNodeTask::Init(TaskContext &context) { | |||||
| Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | ||||
| RuntimeInferenceContext *ctx = nullptr; | RuntimeInferenceContext *ctx = nullptr; | ||||
| GE_CHK_STATUS_RET(RuntimeInferenceContext::GetContext(std::to_string(context.GetSessionId()), &ctx)); | |||||
| GE_CHK_STATUS_RET( | |||||
| RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||||
| ge::Tensor remote_tensor; | ge::Tensor remote_tensor; | ||||
| GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | ||||
| @@ -282,12 +280,13 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto callback = [this](HcclResult status) { | |||||
| TaskContext *p_ctx = &context; | |||||
| auto callback = [p_ctx, done_callback](HcclResult status) { | |||||
| if (status != HCCL_SUCCESS) { | if (status != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", status); | |||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", status); | |||||
| p_ctx->SetStatus(FAILED); | |||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(this->hccl_mutex_); | |||||
| this->cond_.notify_all(); | |||||
| done_callback(); | |||||
| GELOGI("rdma callback success."); | GELOGI("rdma callback success."); | ||||
| }; | }; | ||||
| @@ -297,15 +296,10 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| } | } | ||||
| HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | ||||
| if (hccl_ret != HCCL_SUCCESS) { | if (hccl_ret != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", hccl_ret); | |||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", hccl_ret); | |||||
| return HCCL_E_INTERNAL; | return HCCL_E_INTERNAL; | ||||
| } | } | ||||
| // pending until hccl finished | |||||
| std::unique_lock<std::mutex> ulock(hccl_mutex_); | |||||
| cond_.wait(ulock); | |||||
| (void)context.RegisterCallback(done_callback); | |||||
| GELOGI("[%s] RdmaNodeTask::ExecuteAsync success.", context.GetNodeName()); | GELOGI("[%s] RdmaNodeTask::ExecuteAsync success.", context.GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -27,10 +27,12 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| TaskContext::TaskContext(GraphExecutionContext *execution_context, | TaskContext::TaskContext(GraphExecutionContext *execution_context, | ||||
| const NodeItem *node_item, | |||||
| NodeState *node_state, | |||||
| SubgraphContext *subgraph_context) | SubgraphContext *subgraph_context) | ||||
| : node_item_(node_item), execution_context_(execution_context), subgraph_context_(subgraph_context) { | |||||
| } | |||||
| : node_state_(node_state), | |||||
| node_item_(node_state->GetNodeItem()), | |||||
| execution_context_(execution_context), | |||||
| subgraph_context_(subgraph_context) {} | |||||
| TaskContext::~TaskContext() { | TaskContext::~TaskContext() { | ||||
| GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | ||||
| @@ -47,9 +49,10 @@ TaskContext::~TaskContext() { | |||||
| } | } | ||||
| } | } | ||||
| std::unique_ptr<TaskContext> TaskContext::Create(const NodeItem &node_item, | |||||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
| GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
| SubgraphContext *subgraph_context) { | SubgraphContext *subgraph_context) { | ||||
| const NodeItem &node_item = *node_state->GetNodeItem(); | |||||
| GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| node_item.input_start, | node_item.input_start, | ||||
| @@ -65,7 +68,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(const NodeItem &node_item, | |||||
| } | } | ||||
| auto task_context = std::unique_ptr<TaskContext>( | auto task_context = std::unique_ptr<TaskContext>( | ||||
| new(std::nothrow)TaskContext(execution_context, &node_item, subgraph_context)); | |||||
| new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); | |||||
| if (task_context == nullptr) { | if (task_context == nullptr) { | ||||
| GELOGE(MEMALLOC_FAILED, "[%s] Failed to create instance of TaskContext.", node_item.NodeName().c_str()); | GELOGE(MEMALLOC_FAILED, "[%s] Failed to create instance of TaskContext.", node_item.NodeName().c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -154,7 +157,7 @@ Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) | |||||
| GELOGW("[%s] Callback is NULL", GetNodeName()); | GELOGW("[%s] Callback is NULL", GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); | |||||
| auto ret = execution_context_->callback_manager->RegisterCallback(GetStream(), callback_fun); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | ||||
| execution_context_->callback_manager->Destroy(); | execution_context_->callback_manager->Destroy(); | ||||
| @@ -309,7 +312,7 @@ Status TaskContext::SetOutput(int index, const TensorValue &tensor) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| rtStream_t TaskContext::GetStream() { | |||||
| rtStream_t TaskContext::GetStream() const { | |||||
| return execution_context_->stream; | return execution_context_->stream; | ||||
| } | } | ||||
| @@ -536,6 +539,10 @@ Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| NodeState *TaskContext::GetNodeState() const { | |||||
| return node_state_; | |||||
| } | |||||
| Status TaskContext::SaveProfilingGraphDescInfo(uint32_t task_id, uint32_t stream_id) { | Status TaskContext::SaveProfilingGraphDescInfo(uint32_t task_id, uint32_t stream_id) { | ||||
| if (ProfilingManager::Instance().ProfilingModelExecuteOn()) { | if (ProfilingManager::Instance().ProfilingModelExecuteOn()) { | ||||
| const NodeItem &node_item = GetNodeItem(); | const NodeItem &node_item = GetNodeItem(); | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
| #include "hybrid/common/tensor_value.h" | #include "hybrid/common/tensor_value.h" | ||||
| #include "hybrid/common/npu_memory_allocator.h" | #include "hybrid/common/npu_memory_allocator.h" | ||||
| #include "hybrid/executor/node_state.h" | |||||
| #include "hybrid/executor/rt_callback_manager.h" | #include "hybrid/executor/rt_callback_manager.h" | ||||
| #include "hybrid/model/node_item.h" | #include "hybrid/model/node_item.h" | ||||
| @@ -35,7 +36,7 @@ class SubgraphContext; | |||||
| class TaskContext { | class TaskContext { | ||||
| public: | public: | ||||
| static std::unique_ptr<TaskContext> Create(const NodeItem &node_item, | |||||
| static std::unique_ptr<TaskContext> Create(NodeState *node_state, | |||||
| GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
| SubgraphContext *subgraph_context); | SubgraphContext *subgraph_context); | ||||
| @@ -45,6 +46,7 @@ class TaskContext { | |||||
| int NumOutputs() const; | int NumOutputs() const; | ||||
| size_t NumWorkspaces() const; | size_t NumWorkspaces() const; | ||||
| const NodeItem &GetNodeItem() const; | const NodeItem &GetNodeItem() const; | ||||
| NodeState *GetNodeState() const; | |||||
| const char *GetNodeName() const; | const char *GetNodeName() const; | ||||
| TensorValue *MutableInput(int index); | TensorValue *MutableInput(int index); | ||||
| ConstGeTensorDescPtr GetInputDesc(int index) const; | ConstGeTensorDescPtr GetInputDesc(int index) const; | ||||
| @@ -58,7 +60,7 @@ class TaskContext { | |||||
| const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||
| TensorValue *MutableOutput(int index); | TensorValue *MutableOutput(int index); | ||||
| TensorValue *GetVariable(const std::string &name); | TensorValue *GetVariable(const std::string &name); | ||||
| rtStream_t GetStream(); | |||||
| rtStream_t GetStream() const; | |||||
| int64_t GetSessionId() const; | int64_t GetSessionId() const; | ||||
| uint64_t GetIterationNumber() const; | uint64_t GetIterationNumber() const; | ||||
| @@ -119,12 +121,13 @@ class TaskContext { | |||||
| private: | private: | ||||
| TaskContext(GraphExecutionContext *execution_context, | TaskContext(GraphExecutionContext *execution_context, | ||||
| const NodeItem *node_item, | |||||
| NodeState *node_state, | |||||
| SubgraphContext *subgraph_context); | SubgraphContext *subgraph_context); | ||||
| static string TensorDesc2String(const GeTensorDesc &desc); | static string TensorDesc2String(const GeTensorDesc &desc); | ||||
| Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr); | Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr); | ||||
| NodeState *node_state_ = nullptr; | |||||
| const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
| bool force_infer_shape_ = false; | bool force_infer_shape_ = false; | ||||
| GraphExecutionContext *execution_context_; | GraphExecutionContext *execution_context_; | ||||