@@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, | |||||
} | } | ||||
HybridModelExecutor::~HybridModelExecutor() { | HybridModelExecutor::~HybridModelExecutor() { | ||||
if (context_.rt_gen_context != nullptr) { | |||||
(void) rtCtxDestroy(context_.rt_gen_context); | |||||
} | |||||
} | } | ||||
Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
@@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() { | |||||
Status HybridModelExecutor::InitExecutionContext() { | Status HybridModelExecutor::InitExecutionContext() { | ||||
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
context_.global_step = model_->GetGlobalStep(); | context_.global_step = model_->GetGlobalStep(); | ||||
@@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin | |||||
} | } | ||||
Status StageExecutor::InitExecutionContext() { | Status StageExecutor::InitExecutionContext() { | ||||
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
context_.model = model_; | context_.model = model_; | ||||
@@ -21,10 +21,17 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { | ||||
const auto &node_item = *node_state.GetNodeItem(); | |||||
GE_CHECK_NOTNULL(context); | GE_CHECK_NOTNULL(context); | ||||
rtContext_t rt_gen_context = nullptr; | |||||
GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
std::function<void()> callback = [&]() { | |||||
(void) rtCtxDestroy(rt_gen_context); | |||||
GE_CHK_RT(rtCtxSetCurrent(context->rt_context)); | |||||
}; | |||||
GE_MAKE_GUARD(rt_gen_context, callback); | |||||
const auto &node_item = *node_state.GetNodeItem(); | |||||
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); | ||||
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); | |||||
if (context->ge_context != nullptr) { | if (context->ge_context != nullptr) { | ||||
GetThreadLocalContext() = *context->ge_context; | GetThreadLocalContext() = *context->ge_context; | ||||
@@ -1 +1 @@ | |||||
Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2 | |||||
Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6 |
@@ -1 +1 @@ | |||||
Subproject commit 15a27afefe45f2abdb78787d629163aab9437599 | |||||
Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff |
@@ -27,6 +27,7 @@ | |||||
#include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
#include "hybrid/executor/worker/execution_engine.h" | #include "hybrid/executor/worker/execution_engine.h" | ||||
#include "hybrid/executor/subgraph_executor.h" | #include "hybrid/executor/subgraph_executor.h" | ||||
#include "hybrid/executor/worker/task_compile_engine.h" | |||||
#undef private | #undef private | ||||
#undef protected | #undef protected | ||||
@@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test { | |||||
}; | }; | ||||
namespace { | namespace { | ||||
const int kIntBase = 10; | const int kIntBase = 10; | ||||
class CompileNodeExecutor : public NodeExecutor { | |||||
public: | |||||
Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override { | |||||
return SUCCESS; | |||||
} | |||||
}; | |||||
} | } | ||||
static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | ||||
auto op_desc = std::make_shared<ge::OpDesc>(name, type); | auto op_desc = std::make_shared<ge::OpDesc>(name, type); | ||||
op_desc->SetStreamId(0); | op_desc->SetStreamId(0); | ||||
@@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||||
executor.InitCallback(node_state.get(), callback); | executor.InitCallback(node_state.get(), callback); | ||||
ExecutionEngine execution_engine; | ExecutionEngine execution_engine; | ||||
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | ||||
CompileNodeExecutor node_executor; | |||||
node_item->node_executor = &node_executor; | |||||
EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS); | |||||
} | } |