Browse Source

!679 Fixing runtime compile options

From: @xchu42
Reviewed-by: @ji_chen,@wqtshg
Signed-off-by: @ji_chen
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
77b2f4b6de
3 changed files with 6 additions and 0 deletions
  1. +2
    -0
      ge/hybrid/executor/hybrid_execution_context.h
  2. +1
    -0
      ge/hybrid/executor/hybrid_model_executor.cc
  3. +3
    -0
      ge/hybrid/executor/worker/task_compile_engine.cc

+ 2
- 0
ge/hybrid/executor/hybrid_execution_context.h View File

@@ -22,6 +22,7 @@
#include "common/blocking_queue.h"
#include "common/properties_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_local_context.h"
#include "hybrid/common/npu_memory_allocator.h"
#include "hybrid/common/tensor_value.h"
#include "hybrid/executor/hybrid_profiler.h"
@@ -38,6 +39,7 @@ struct GraphExecutionContext {

uint64_t session_id = 0;
const HybridModel *model = nullptr;
const GEThreadLocalContext *ge_context = nullptr;
rtStream_t stream = nullptr;
rtContext_t rt_context = nullptr;
rtContext_t rt_gen_context = nullptr;


+ 1
- 0
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -95,6 +95,7 @@ Status HybridModelExecutor::InitExecutionContext() {
context_.stream = stream_;
context_.model = model_;
context_.session_id = ::ge::GetContext().SessionId();
context_.ge_context = &GetThreadLocalContext();
GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id);
context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_);
GE_CHECK_NOTNULL(context_.allocator);


+ 3
- 0
ge/hybrid/executor/worker/task_compile_engine.cc View File

@@ -26,6 +26,9 @@ Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start");
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context));

if (context->ge_context != nullptr) {
GetThreadLocalContext() = *context->ge_context;
}
shared_ptr<NodeTask> kernel_task;
auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task);
RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "[Compile] End");


Loading…
Cancel
Save