| @@ -71,6 +71,7 @@ struct GraphExecutionContext { | |||||
| std::atomic_bool is_eos_; | std::atomic_bool is_eos_; | ||||
| long profiling_level = 0; | long profiling_level = 0; | ||||
| long iteration = 0; | long iteration = 0; | ||||
| void *global_step = nullptr; | |||||
| private: | private: | ||||
| Status status = SUCCESS; | Status status = SUCCESS; | ||||
| @@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
| if (context_.rt_gen_context != nullptr) { | if (context_.rt_gen_context != nullptr) { | ||||
| (void) rtCtxDestroy(context_.rt_gen_context); | (void) rtCtxDestroy(context_.rt_gen_context); | ||||
| } | } | ||||
| if (context_.global_step != nullptr) { | |||||
| (void) rtFree(context_.global_step); | |||||
| } | |||||
| } | } | ||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| @@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| auto root_graph_item = model_->GetRootGraphItem(); | auto root_graph_item = model_->GetRootGraphItem(); | ||||
| GE_CHECK_NOTNULL(root_graph_item); | GE_CHECK_NOTNULL(root_graph_item); | ||||
| GE_CHK_RT_RET(rtMemcpy(context_.global_step, sizeof(uint64_t), &context_.iteration, | |||||
| sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE)); | |||||
| SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | ||||
| auto ret = ExecuteGraphInternal(executor, args); | auto ret = ExecuteGraphInternal(executor, args); | ||||
| Cleanup(); | Cleanup(); | ||||
| @@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | ||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | ||||
| GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM)); | |||||
| context_.stream = stream_; | context_.stream = stream_; | ||||
| context_.model = model_; | context_.model = model_; | ||||
| @@ -231,12 +231,6 @@ Status NodeDoneCallback::DumpDynamicNode() { | |||||
| uint32_t model_id = model->GetModelId(); | uint32_t model_id = model->GetModelId(); | ||||
| dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); | dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id); | ||||
| void *global_step = nullptr; | |||||
| TensorValue *varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP); | |||||
| if (varible_global_step != nullptr) { | |||||
| global_step = const_cast<void *>(varible_global_step->GetData()); | |||||
| } | |||||
| void *loop_per_iter = nullptr; | void *loop_per_iter = nullptr; | ||||
| TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); | TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER); | ||||
| if (varible_loop_per_iter != nullptr) { | if (varible_loop_per_iter != nullptr) { | ||||
| @@ -248,6 +242,7 @@ Status NodeDoneCallback::DumpDynamicNode() { | |||||
| if (varible_loop_cond != nullptr) { | if (varible_loop_cond != nullptr) { | ||||
| loop_cond = const_cast<void *>(varible_loop_cond->GetData()); | loop_cond = const_cast<void *>(varible_loop_cond->GetData()); | ||||
| } | } | ||||
| void *global_step = context_->GetExecutionContext()->global_step; | |||||
| dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); | dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond); | ||||
| GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); | GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model"); | ||||
| @@ -126,11 +126,7 @@ Status KnownNodeTask::Init(TaskContext &context) { | |||||
| auto dump_properties = context.GetDumpProperties(); | auto dump_properties = context.GetDumpProperties(); | ||||
| if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { | if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) { | ||||
| davinci_model_->SetDumpProperties(dump_properties); | davinci_model_->SetDumpProperties(dump_properties); | ||||
| void *global_step = nullptr; | |||||
| TensorValue *varible_global_step = context.GetVariable(NODE_NAME_GLOBAL_STEP); | |||||
| if (varible_global_step != nullptr) { | |||||
| global_step = varible_global_step->MutableData(); | |||||
| } | |||||
| void *global_step = context.GetExecutionContext()->global_step;; | |||||
| davinci_model_->SetKnownShapeGlobalStep(global_step); | davinci_model_->SetKnownShapeGlobalStep(global_step); | ||||
| } | } | ||||
| int32_t device_id = 0; | int32_t device_id = 0; | ||||
| @@ -30,6 +30,7 @@ | |||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/executor/hybrid_model_executor.h" | |||||
| #include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
| #include "graph/load/model_manager/tbe_handle_store.h" | #include "graph/load/model_manager/tbe_handle_store.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| @@ -242,4 +243,20 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||||
| ge_sub_model->SetWeight(weight_buffer); | ge_sub_model->SetWeight(weight_buffer); | ||||
| ret = hybrid_model_builder.InitWeights(); | ret = hybrid_model_builder.InitWeights(); | ||||
| ASSERT_EQ(ret,PARAM_INVALID); | ASSERT_EQ(ret,PARAM_INVALID); | ||||
| } | |||||
| } | |||||
| TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | |||||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||||
| //auto graph_item = std::unique_ptr<GraphItem>(new(std::nothrow)GraphItem()); | |||||
| HybridModel model(root_model); | |||||
| //model.root_graph_item_ = graph_item; | |||||
| HybridModel *model_ptr = &model; | |||||
| uint32_t device_id = 0; | |||||
| rtStream_t stream; | |||||
| HybridModelExecutor executor(model_ptr, device_id, stream); | |||||
| executor.Init(); | |||||
| HybridModelExecutor::ExecuteArgs args; | |||||
| executor.Execute(args); | |||||
| } | |||||