Merge pull request !1920 from 赵之轩/my_devtags/v1.5.1
| @@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, | |||||
| } | } | ||||
| HybridModelExecutor::~HybridModelExecutor() { | HybridModelExecutor::~HybridModelExecutor() { | ||||
| if (context_.rt_gen_context != nullptr) { | |||||
| (void) rtCtxDestroy(context_.rt_gen_context); | |||||
| } | |||||
| } | } | ||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| @@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() { | |||||
| Status HybridModelExecutor::InitExecutionContext() { | Status HybridModelExecutor::InitExecutionContext() { | ||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| context_.global_step = model_->GetGlobalStep(); | context_.global_step = model_->GetGlobalStep(); | ||||
| @@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin | |||||
| } | } | ||||
| Status StageExecutor::InitExecutionContext() { | Status StageExecutor::InitExecutionContext() { | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| context_.model = model_; | context_.model = model_; | ||||
| @@ -21,10 +21,17 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | ||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| GE_CHECK_NOTNULL(context); | GE_CHECK_NOTNULL(context); | ||||
| rtContext_t rt_gen_context = nullptr; | |||||
| GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| std::function<void()> callback = [&]() { | |||||
| (void) rtCtxDestroy(rt_gen_context); | |||||
| GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); | |||||
| }; | |||||
| GE_MAKE_GUARD(rt_gen_context, callback); | |||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | ||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | |||||
| if (context->ge_context != nullptr) { | if (context->ge_context != nullptr) { | ||||
| GetThreadLocalContext() = *context->ge_context; | GetThreadLocalContext() = *context->ge_context; | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
| #include "hybrid/executor/worker/execution_engine.h" | #include "hybrid/executor/worker/execution_engine.h" | ||||
| #include "hybrid/executor/subgraph_executor.h" | #include "hybrid/executor/subgraph_executor.h" | ||||
| #include "hybrid/executor/worker/task_compile_engine.h" | |||||
| #undef private | #undef private | ||||
| #undef protected | #undef protected | ||||
| @@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test { | |||||
| }; | }; | ||||
| namespace { | namespace { | ||||
| const int kIntBase = 10; | const int kIntBase = 10; | ||||
| class CompileNodeExecutor : public NodeExecutor { | |||||
| public: | |||||
| Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override { | |||||
| return SUCCESS; | |||||
| } | |||||
| }; | |||||
| } | } | ||||
| static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | ||||
| auto op_desc = std::make_shared<ge::OpDesc>(name, type); | auto op_desc = std::make_shared<ge::OpDesc>(name, type); | ||||
| op_desc->SetStreamId(0); | op_desc->SetStreamId(0); | ||||
| @@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
| executor.InitCallback(node_state.get(), callback); | executor.InitCallback(node_state.get(), callback); | ||||
| ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
| EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | ||||
| CompileNodeExecutor node_executor; | |||||
| node_item->node_executor = &node_executor; | |||||
| EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS); | |||||
| } | } | ||||