Hybrid executor: add a global_step buffer to GraphExecutionContext.

InitExecutionContext() allocates sizeof(uint64_t) bytes with rtMalloc(RT_MEMORY_HBM),
Execute() copies context_.iteration into the buffer with rtMemcpyAsync on context_.stream
before running the root graph, and ~HybridModelExecutor() releases it with rtFree.

NOTE(review): the hunks for hybrid_execution_context.h and hybrid_model_executor.cc were
re-laid-out from a whitespace-collapsed copy; indentation and the exact position of blank
context lines are reconstructed -- verify against the target tree before applying.

diff --git a/ge/hybrid/executor/hybrid_execution_context.h b/ge/hybrid/executor/hybrid_execution_context.h
index 4dc010df..003e8010 100644
--- a/ge/hybrid/executor/hybrid_execution_context.h
+++ b/ge/hybrid/executor/hybrid_execution_context.h
@@ -71,6 +71,7 @@ struct GraphExecutionContext {
   std::atomic_bool is_eos_;
   long profiling_level = 0;
   long iteration = 0;
+  void *global_step = nullptr;
 
  private:
   Status status = SUCCESS;
diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc
index 80b8983a..4b589a03 100755
--- a/ge/hybrid/executor/hybrid_model_executor.cc
+++ b/ge/hybrid/executor/hybrid_model_executor.cc
@@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() {
   if (context_.rt_gen_context != nullptr) {
     (void) rtCtxDestroy(context_.rt_gen_context);
   }
+  if (context_.global_step != nullptr) {
+    (void) rtFree(context_.global_step);
+  }
 }
 
 Status HybridModelExecutor::Init() {
@@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
   auto root_graph_item = model_->GetRootGraphItem();
   GE_CHECK_NOTNULL(root_graph_item);
 
+  GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
+                              sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
   SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
   auto ret = ExecuteGraphInternal(executor, args);
   Cleanup();
@@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() {
   GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
   GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
   GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));
+  GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM));
 
   context_.stream = stream_;
   context_.model = model_;
diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc
index 673c82dd..de3bdc37 100755
--- a/ge/hybrid/executor/worker/execution_engine.cc
+++ 
b/ge/hybrid/executor/worker/execution_engine.cc
@@ -231,12 +231,6 @@ Status NodeDoneCallback::DumpDynamicNode() {
   uint32_t model_id = model->GetModelId();
   dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id);
 
-  void *global_step = nullptr;
-  TensorValue *varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP);
-  if (varible_global_step != nullptr) {
-    global_step = const_cast<void *>(varible_global_step->GetData());
-  }
-
   void *loop_per_iter = nullptr;
   TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
   if (varible_loop_per_iter != nullptr) {
@@ -248,6 +242,7 @@ Status NodeDoneCallback::DumpDynamicNode() {
   if (varible_loop_cond != nullptr) {
     loop_cond = const_cast<void *>(varible_loop_cond->GetData());
   }
+  void *global_step = context_->GetExecutionContext()->global_step;
   dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond);
 
   GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model");
diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
index cf5ac851..bb96c275 100755
--- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
+++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
@@ -126,11 +126,7 @@ Status KnownNodeTask::Init(TaskContext &context) {
   auto dump_properties = context.GetDumpProperties();
   if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) {
     davinci_model_->SetDumpProperties(dump_properties);
-    void *global_step = nullptr;
-    TensorValue *varible_global_step = context.GetVariable(NODE_NAME_GLOBAL_STEP);
-    if (varible_global_step != nullptr) {
-      global_step = varible_global_step->MutableData();
-    }
+    void *global_step = context.GetExecutionContext()->global_step;
     davinci_model_->SetKnownShapeGlobalStep(global_step);
   }
   int32_t device_id = 0;
diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
index 
d7116dbc..3b5d19e6 100644
--- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
+++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
@@ -30,6 +30,7 @@
 #include "framework/common/debug/log.h"
 #include "graph/ge_context.h"
 #include "hybrid/executor/hybrid_execution_context.h"
+#include "hybrid/executor/hybrid_model_executor.h"
 #include "hybrid/node_executor/aicore/aicore_task_builder.h"
 #include "graph/load/model_manager/tbe_handle_store.h"
 #include "graph/manager/graph_mem_allocator.h"
@@ -242,4 +243,18 @@ TEST_F(UtestGeHybrid, init_weight_success) {
   ge_sub_model->SetWeight(weight_buffer);
   ret = hybrid_model_builder.InitWeights();
   ASSERT_EQ(ret,PARAM_INVALID);
-}
\ No newline at end of file
+}
+
+// Smoke test: builds a HybridModelExecutor over a HybridModel wrapping a freshly constructed
+// ComputeGraph named "abc" and calls Init(). The status returned by Init() is not checked.
+TEST_F(UtestGeHybrid, hybrid_model_executor) {
+  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc");
+  GeRootModelPtr root_model = MakeShared<GeRootModel>(compute_graph);
+  HybridModel model(root_model);
+  HybridModel *model_ptr = &model;
+
+  uint32_t device_id = 0;
+  rtStream_t stream = nullptr;
+  HybridModelExecutor executor(model_ptr, device_id, stream);
+  executor.Init();
+}