| @@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, | |||||
| } | } | ||||
| HybridModelExecutor::~HybridModelExecutor() { | HybridModelExecutor::~HybridModelExecutor() { | ||||
| if (context_.rt_gen_context != nullptr) { | |||||
| (void) rtCtxDestroy(context_.rt_gen_context); | |||||
| } | |||||
| } | } | ||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| @@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() { | |||||
| Status HybridModelExecutor::InitExecutionContext() { | Status HybridModelExecutor::InitExecutionContext() { | ||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| context_.global_step = model_->GetGlobalStep(); | context_.global_step = model_->GetGlobalStep(); | ||||
| @@ -175,7 +175,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin | |||||
| } | } | ||||
| Status StageExecutor::InitExecutionContext() { | Status StageExecutor::InitExecutionContext() { | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| context_.model = model_; | context_.model = model_; | ||||
| @@ -21,10 +21,17 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | ||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| GE_CHECK_NOTNULL(context); | GE_CHECK_NOTNULL(context); | ||||
| rtContext_t rt_gen_context = nullptr; | |||||
| GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| std::function<void()> callback = [&]() { | |||||
| (void) rtCtxDestroy(rt_gen_context); | |||||
| GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); | |||||
| }; | |||||
| GE_MAKE_GUARD(rt_gen_context, callback); | |||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | ||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | |||||
| if (context->ge_context != nullptr) { | if (context->ge_context != nullptr) { | ||||
| GetThreadLocalContext() = *context->ge_context; | GetThreadLocalContext() = *context->ge_context; | ||||