From 1e9fdaebb62d846ebd1729e17a17c7cde59946f0 Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 21 Dec 2020 16:35:48 +0800 Subject: [PATCH] fixing runtime compile options --- ge/hybrid/executor/hybrid_execution_context.h | 2 ++ ge/hybrid/executor/hybrid_model_executor.cc | 1 + ge/hybrid/executor/worker/task_compile_engine.cc | 3 +++ 3 files changed, 6 insertions(+) diff --git a/ge/hybrid/executor/hybrid_execution_context.h b/ge/hybrid/executor/hybrid_execution_context.h index 1fe40c77..f1c25290 100644 --- a/ge/hybrid/executor/hybrid_execution_context.h +++ b/ge/hybrid/executor/hybrid_execution_context.h @@ -22,6 +22,7 @@ #include "common/blocking_queue.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" +#include "graph/ge_local_context.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/common/tensor_value.h" #include "hybrid/executor/hybrid_profiler.h" @@ -38,6 +39,7 @@ struct GraphExecutionContext { uint64_t session_id = 0; const HybridModel *model = nullptr; + const GEThreadLocalContext *ge_context = nullptr; rtStream_t stream = nullptr; rtContext_t rt_context = nullptr; rtContext_t rt_gen_context = nullptr; diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index 8ba687c2..e17998db 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -95,6 +95,7 @@ Status HybridModelExecutor::InitExecutionContext() { context_.stream = stream_; context_.model = model_; context_.session_id = ::ge::GetContext().SessionId(); + context_.ge_context = &GetThreadLocalContext(); GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); GE_CHECK_NOTNULL(context_.allocator); diff --git a/ge/hybrid/executor/worker/task_compile_engine.cc b/ge/hybrid/executor/worker/task_compile_engine.cc index e2e94f66..f80374c6 100755 --- a/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/ge/hybrid/executor/worker/task_compile_engine.cc @@ -26,6 +26,9 @@ Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext * RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); + if (context->ge_context != nullptr) { + GetThreadLocalContext() = *context->ge_context; + } shared_ptr kernel_task; auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End");