@@ -71,6 +71,7 @@ struct GraphExecutionContext { | |||||
std::atomic_bool is_eos_; | std::atomic_bool is_eos_; | ||||
long profiling_level = 0; | long profiling_level = 0; | ||||
long iteration = 0; | long iteration = 0; | ||||
void *global_step = nullptr; | |||||
private: | private: | ||||
Status status = SUCCESS; | Status status = SUCCESS; | ||||
@@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
if (context_.rt_gen_context != nullptr) { | if (context_.rt_gen_context != nullptr) { | ||||
(void) rtCtxDestroy(context_.rt_gen_context); | (void) rtCtxDestroy(context_.rt_gen_context); | ||||
} | } | ||||
if (context_.global_step != nullptr) { | |||||
(void) rtFree(context_.global_step); | |||||
} | |||||
} | } | ||||
Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
@@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
auto root_graph_item = model_->GetRootGraphItem(); | auto root_graph_item = model_->GetRootGraphItem(); | ||||
GE_CHECK_NOTNULL(root_graph_item); | GE_CHECK_NOTNULL(root_graph_item); | ||||
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | |||||
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | |||||
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | ||||
auto ret = ExecuteGraphInternal(executor, args); | auto ret = ExecuteGraphInternal(executor, args); | ||||
Cleanup(); | Cleanup(); | ||||
@@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | ||||
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM)); | |||||
context_.stream = stream_; | context_.stream = stream_; | ||||
context_.model = model_; | context_.model = model_; | ||||
@@ -231,12 +231,6 @@ Status NodeDoneCallback::DumpDynamicNode() { | |||||
uint32_t model_id = model->GetModelId(); | uint32_t model_id = model->GetModelId(); | ||||
dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); | dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); | ||||
void *global_step = nullptr; | |||||
TensorValue *varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP); | |||||
if (varible_global_step != nullptr) { | |||||
global_step = const_cast<void *>(varible_global_step->GetData()); | |||||
} | |||||
void *loop_per_iter = nullptr; | void *loop_per_iter = nullptr; | ||||
TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); | TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); | ||||
if (varible_loop_per_iter != nullptr) { | if (varible_loop_per_iter != nullptr) { | ||||
@@ -248,6 +242,7 @@ Status NodeDoneCallback::DumpDynamicNode() { | |||||
if (varible_loop_cond != nullptr) { | if (varible_loop_cond != nullptr) { | ||||
loop_cond = const_cast<void *>(varible_loop_cond->GetData()); | loop_cond = const_cast<void *>(varible_loop_cond->GetData()); | ||||
} | } | ||||
void *global_step = context_->GetExecutionContext()->global_step; | |||||
dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); | dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); | ||||
GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); | GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); | ||||
@@ -126,11 +126,7 @@ Status KnownNodeTask::Init(TaskContext &context) { | |||||
auto dump_properties = context.GetDumpProperties(); | auto dump_properties = context.GetDumpProperties(); | ||||
if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { | if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { | ||||
davinci_model_->SetDumpProperties(dump_properties); | davinci_model_->SetDumpProperties(dump_properties); | ||||
void *global_step = nullptr; | |||||
TensorValue *varible_global_step = context.GetVariable(NODE_NAME_GLOBAL_STEP); | |||||
if (varible_global_step != nullptr) { | |||||
global_step = varible_global_step->MutableData(); | |||||
} | |||||
void *global_step = context.GetExecutionContext()->global_step; | |||||
davinci_model_->SetKnownShapeGlobalStep(global_step); | davinci_model_->SetKnownShapeGlobalStep(global_step); | ||||
} | } | ||||
int32_t device_id = 0; | int32_t device_id = 0; | ||||
@@ -30,6 +30,7 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
#include "hybrid/executor/hybrid_model_executor.h" | |||||
#include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
#include "graph/load/model_manager/tbe_handle_store.h" | #include "graph/load/model_manager/tbe_handle_store.h" | ||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
@@ -242,4 +243,16 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||||
ge_sub_model->SetWeight(weight_buffer); | ge_sub_model->SetWeight(weight_buffer); | ||||
ret = hybrid_model_builder.InitWeights(); | ret = hybrid_model_builder.InitWeights(); | ||||
ASSERT_EQ(ret,PARAM_INVALID); | ASSERT_EQ(ret,PARAM_INVALID); | ||||
} | |||||
} | |||||
TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | |||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||||
HybridModel model(root_model); | |||||
HybridModel *model_ptr = &model; | |||||
uint32_t device_id = 0; | |||||
rtStream_t stream; | |||||
HybridModelExecutor executor(model_ptr, device_id, stream); | |||||
executor.Init(); | |||||
} |