Browse Source

Do not create context in hydrid executor init func.

tags/v1.5.1
zhaozhixuan 3 years ago
parent
commit
2400e65904
6 changed files with 23 additions and 9 deletions
  1. +0
    -4
      ge/hybrid/executor/hybrid_model_executor.cc
  2. +0
    -1
      ge/hybrid/executor/hybrid_model_pipeline_executor.cc
  3. +9
    -2
      ge/hybrid/executor/worker/task_compile_engine.cc
  4. +1
    -1
      metadef
  5. +1
    -1
      parser
  6. +12
    -0
      tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc

+ 0
- 4
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id,
} }


HybridModelExecutor::~HybridModelExecutor() { HybridModelExecutor::~HybridModelExecutor() {
if (context_.rt_gen_context != nullptr) {
(void) rtCtxDestroy(context_.rt_gen_context);
}
} }


Status HybridModelExecutor::Init() { Status HybridModelExecutor::Init() {
@@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() {


Status HybridModelExecutor::InitExecutionContext() { Status HybridModelExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));


context_.global_step = model_->GetGlobalStep(); context_.global_step = model_->GetGlobalStep();


+ 0
- 1
ge/hybrid/executor/hybrid_model_pipeline_executor.cc View File

@@ -191,7 +191,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin
} }


Status StageExecutor::InitExecutionContext() { Status StageExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));


context_.model = model_; context_.model = model_;


+ 9
- 2
ge/hybrid/executor/worker/task_compile_engine.cc View File

@@ -21,10 +21,17 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) {
const auto &node_item = *node_state.GetNodeItem();
GE_CHECK_NOTNULL(context); GE_CHECK_NOTNULL(context);
rtContext_t rt_gen_context = nullptr;
GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0));
std::function<void()> callback = [&]() {
(void) rtCtxDestroy(rt_gen_context);
GE_CHK_RT(rtCtxSetCurrent(context->rt_context));
};
GE_MAKE_GUARD(rt_gen_context, callback);

const auto &node_item = *node_state.GetNodeItem();
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start");
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context));


if (context->ge_context != nullptr) { if (context->ge_context != nullptr) {
GetThreadLocalContext() = *context->ge_context; GetThreadLocalContext() = *context->ge_context;


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2
Subproject commit 9e4a51a9602195b82e326b853f5adbfefc3972b6

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 15a27afefe45f2abdb78787d629163aab9437599
Subproject commit 79536a196f89cf7a1f5852ff7304b9a7d7b12eff

+ 12
- 0
tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc View File

@@ -27,6 +27,7 @@
#include "hybrid/executor/hybrid_model_executor.h" #include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/executor/worker/execution_engine.h" #include "hybrid/executor/worker/execution_engine.h"
#include "hybrid/executor/subgraph_executor.h" #include "hybrid/executor/subgraph_executor.h"
#include "hybrid/executor/worker/task_compile_engine.h"
#undef private #undef private
#undef protected #undef protected


@@ -45,7 +46,14 @@ class UtestExecutionEngine : public testing::Test {
}; };
namespace { namespace {
const int kIntBase = 10; const int kIntBase = 10;
class CompileNodeExecutor : public NodeExecutor {
public:
Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const override {
return SUCCESS;
}
};
} }

static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") {
auto op_desc = std::make_shared<ge::OpDesc>(name, type); auto op_desc = std::make_shared<ge::OpDesc>(name, type);
op_desc->SetStreamId(0); op_desc->SetStreamId(0);
@@ -128,4 +136,8 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
executor.InitCallback(node_state.get(), callback); executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine; ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
CompileNodeExecutor node_executor;
node_item->node_executor = &node_executor;
EXPECT_EQ(TaskCompileEngine::Compile(*node_state, &execution_context), SUCCESS);
} }

Loading…
Cancel
Save