diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index 58da451c..2bb683c7 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, } HybridModelExecutor::~HybridModelExecutor() { - if (context_.rt_gen_context != nullptr) { - (void) rtCtxDestroy(context_.rt_gen_context); - } } Status HybridModelExecutor::Init() { @@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() { Status HybridModelExecutor::InitExecutionContext() { GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); - GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); context_.global_step = model_->GetGlobalStep(); diff --git a/ge/hybrid/executor/hybrid_model_pipeline_executor.cc b/ge/hybrid/executor/hybrid_model_pipeline_executor.cc index 45e61138..b5e66628 100644 --- a/ge/hybrid/executor/hybrid_model_pipeline_executor.cc +++ b/ge/hybrid/executor/hybrid_model_pipeline_executor.cc @@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin } Status StageExecutor::InitExecutionContext() { - GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); context_.model = model_; diff --git a/ge/hybrid/executor/worker/task_compile_engine.cc b/ge/hybrid/executor/worker/task_compile_engine.cc index f7da9acd..491e0997 100755 --- a/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/ge/hybrid/executor/worker/task_compile_engine.cc @@ -21,10 +21,17 @@ namespace ge { namespace hybrid { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { - const auto &node_item = *node_state.GetNodeItem(); GE_CHECK_NOTNULL(context); + rtContext_t rt_gen_context = nullptr; + GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); + std::function callback = [&]() { + (void) rtCtxDestroy(rt_gen_context); + GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); + }; + GE_MAKE_GUARD(rt_gen_context, callback); + + const auto &node_item = *node_state.GetNodeItem(); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); - GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); if (context->ge_context != nullptr) { GetThreadLocalContext() = *context->ge_context; diff --git a/metadef b/metadef index f3f137de..9e4a51a9 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2 +Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6 diff --git a/parser b/parser index 15a27afe..79536a19 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 15a27afefe45f2abdb78787d629163aab9437599 +Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff diff --git a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc index 07701f4d..96641c59 100644 --- a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc +++ b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc @@ -27,6 +27,7 @@ #include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/executor/worker/execution_engine.h" #include "hybrid/executor/subgraph_executor.h" +#include "hybrid/executor/worker/task_compile_engine.h" #undef private #undef protected @@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test { }; namespace { const int kIntBase = 10; +class CompileNodeExecutor : public NodeExecutor { + public: + Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const override { + return SUCCESS; + } +}; } + static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { auto op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); @@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { executor.InitCallback(node_state.get(), callback); ExecutionEngine execution_engine; EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); + + CompileNodeExecutor node_executor; + node_item->node_executor = &node_executor; + EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS); }