@@ -22,6 +22,7 @@ | |||||
#include "common/blocking_queue.h" | #include "common/blocking_queue.h" | ||||
#include "common/properties_manager.h" | #include "common/properties_manager.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/ge_local_context.h" | |||||
#include "hybrid/common/npu_memory_allocator.h" | #include "hybrid/common/npu_memory_allocator.h" | ||||
#include "hybrid/common/tensor_value.h" | #include "hybrid/common/tensor_value.h" | ||||
#include "hybrid/executor/hybrid_profiler.h" | #include "hybrid/executor/hybrid_profiler.h" | ||||
@@ -38,6 +39,7 @@ struct GraphExecutionContext { | |||||
uint64_t session_id = 0; | uint64_t session_id = 0; | ||||
const HybridModel *model = nullptr; | const HybridModel *model = nullptr; | ||||
const GEThreadLocalContext *ge_context = nullptr; | |||||
rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
rtContext_t rt_context = nullptr; | rtContext_t rt_context = nullptr; | ||||
rtContext_t rt_gen_context = nullptr; | rtContext_t rt_gen_context = nullptr; | ||||
@@ -95,6 +95,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
context_.stream = stream_; | context_.stream = stream_; | ||||
context_.model = model_; | context_.model = model_; | ||||
context_.session_id = ::ge::GetContext().SessionId(); | context_.session_id = ::ge::GetContext().SessionId(); | ||||
context_.ge_context = &GetThreadLocalContext(); | |||||
GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | ||||
context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | ||||
GE_CHECK_NOTNULL(context_.allocator); | GE_CHECK_NOTNULL(context_.allocator); | ||||
@@ -26,6 +26,9 @@ Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext * | |||||
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | ||||
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | ||||
if (context->ge_context != nullptr) { | |||||
GetThreadLocalContext() = *context->ge_context; | |||||
} | |||||
shared_ptr<NodeTask> kernel_task; | shared_ptr<NodeTask> kernel_task; | ||||
auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); | auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); | ||||
RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End"); | RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End"); | ||||