Browse Source

Pre Merge pull request !1932 from 赵之轩/my_dev

pull/1932/MERGE
赵之轩 Gitee 4 years ago
parent
commit
2582470b82
3 changed files with 9 additions and 7 deletions
  1. +0
    -4
      ge/hybrid/executor/hybrid_model_executor.cc
  2. +0
    -1
      ge/hybrid/executor/hybrid_model_pipeline_executor.cc
  3. +9
    -2
      ge/hybrid/executor/worker/task_compile_engine.cc

+ 0
- 4
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -33,9 +33,6 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id,
} }


HybridModelExecutor::~HybridModelExecutor() { HybridModelExecutor::~HybridModelExecutor() {
if (context_.rt_gen_context != nullptr) {
(void) rtCtxDestroy(context_.rt_gen_context);
}
} }


Status HybridModelExecutor::Init() { Status HybridModelExecutor::Init() {
@@ -139,7 +136,6 @@ Status HybridModelExecutor::Cleanup() {


Status HybridModelExecutor::InitExecutionContext() { Status HybridModelExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));


context_.global_step = model_->GetGlobalStep(); context_.global_step = model_->GetGlobalStep();


+ 0
- 1
ge/hybrid/executor/hybrid_model_pipeline_executor.cc View File

@@ -175,7 +175,6 @@ HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uin
} }


Status StageExecutor::InitExecutionContext() { Status StageExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));


context_.model = model_; context_.model = model_;


+ 9
- 2
ge/hybrid/executor/worker/task_compile_engine.cc View File

@@ -21,10 +21,17 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) {
const auto &node_item = *node_state.GetNodeItem();
GE_CHECK_NOTNULL(context); GE_CHECK_NOTNULL(context);
rtContext_t rt_gen_context = nullptr;
GE_CHK_RT_RET(rtCtxCreate(&rt_gen_context, RT_CTX_GEN_MODE, 0));
std::function<void()> callback = [&]() {
(void) rtCtxDestroy(rt_gen_context);
GE_CHK_RT(rtCtxSetCurrent(context->rt_context));
};
GE_MAKE_GUARD(rt_gen_context, callback);

const auto &node_item = *node_state.GetNodeItem();
RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start"); RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "[Compile] Start");
GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context));


if (context->ge_context != nullptr) { if (context->ge_context != nullptr) {
GetThreadLocalContext() = *context->ge_context; GetThreadLocalContext() = *context->ge_context;


Loading…
Cancel
Save