| @@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
| Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
| GELOGD("Start to init HybridGraphEngine."); | GELOGD("Start to init HybridGraphEngine."); | ||||
| GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | ||||
| root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
| GE_CHECK_NOTNULL(root_graph_executor_); | |||||
| GELOGD("HybridGraphEngine initialized successfully."); | GELOGD("HybridGraphEngine initialized successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | ||||
| sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | ||||
| } | } | ||||
| SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | |||||
| auto ret = ExecuteGraphInternal(executor, args); | |||||
| auto ret = ExecuteGraphInternal(args); | |||||
| Cleanup(); | Cleanup(); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | ||||
| GELOGD("Model executed successfully."); | GELOGD("Model executed successfully."); | ||||
| @@ -79,8 +80,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| HybridModelExecutor::ExecuteArgs &args) { | |||||
| Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | ||||
| GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | ||||
| @@ -94,7 +94,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | ||||
| } | } | ||||
| HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
| HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
| "Failed to execute partitioned call."); | "Failed to execute partitioned call."); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | ||||
| @@ -103,7 +103,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| } | } | ||||
| if (!model_->IsSingleOp()) { | if (!model_->IsSingleOp()) { | ||||
| Status ret = executor.Synchronize(); | |||||
| Status ret = root_graph_executor_->Synchronize(); | |||||
| if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| @@ -123,7 +123,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| } | } | ||||
| args.outputs.clear(); | args.outputs.clear(); | ||||
| HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
| HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -48,7 +48,7 @@ class HybridModelExecutor { | |||||
| Status Execute(ExecuteArgs &args); | Status Execute(ExecuteArgs &args); | ||||
| private: | private: | ||||
| Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); | |||||
| Status ExecuteGraphInternal(ExecuteArgs &args); | |||||
| Status Cleanup(); | Status Cleanup(); | ||||
| Status InitExecutionContext(); | Status InitExecutionContext(); | ||||
| static Status ResetExecutionContext(GraphExecutionContext &context); | static Status ResetExecutionContext(GraphExecutionContext &context); | ||||
| @@ -58,6 +58,7 @@ class HybridModelExecutor { | |||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| rtStream_t stream_; | rtStream_t stream_; | ||||
| GraphExecutionContext context_; | GraphExecutionContext context_; | ||||
| std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -44,20 +44,56 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t kDataOutputNum = 1; | const size_t kDataOutputNum = 1; | ||||
| const uint32_t kInputIndexOfData = 0; | |||||
| const uint32_t kOutputIndexOfData = 0; | const uint32_t kOutputIndexOfData = 0; | ||||
| constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | ||||
| Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &flag) { | |||||
| for (const auto &input_name : dependencies) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| int input_index = op_desc->GetInputIndexByName(input_name); | |||||
| if (input_index < 0) { | |||||
| GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", | |||||
| node->GetName().c_str(), input_name.c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", | |||||
| node->GetName().c_str(), input_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const auto &in_anchor = node->GetInDataAnchor(input_index); | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| const auto &src_node = peer_out_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto src_op_desc = src_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(src_op_desc); | |||||
| if (src_op_desc->GetType() == DATA) { | |||||
| auto tensor = src_op_desc->MutableInputDesc(kInputIndexOfData); | |||||
| if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { | |||||
| GELOGD("Get hostmem from node %s, inputname: %s.", src_node->GetName().c_str(), input_name.c_str()); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| flag = false; | |||||
| return SUCCESS; | |||||
| } | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | ||||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | ||||
| GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
| for (const auto &node : comp_graph->GetAllNodes()) { | for (const auto &node : comp_graph->GetAllNodes()) { | ||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| const auto &depends = op_desc->GetOpInferDepends(); | const auto &depends = op_desc->GetOpInferDepends(); | ||||
| bool support_dynamic_shape = false; | bool support_dynamic_shape = false; | ||||
| (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | ||||
| if (!depends.empty() && support_dynamic_shape) { | if (!depends.empty() && support_dynamic_shape) { | ||||
| flag = true; | |||||
| CheckHostMem(depends, node, flag); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } | } | ||||
| @@ -92,16 +92,15 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
| ge_root_model->SetModelName("test_name"); | ge_root_model->SetModelName("test_name"); | ||||
| HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
| hybrid_model.root_graph_item_.reset(new GraphItem); | |||||
| HybridModelExecutor executor(&hybrid_model, 0, nullptr); | HybridModelExecutor executor(&hybrid_model, 0, nullptr); | ||||
| ASSERT_EQ(executor.Init(), SUCCESS); | ASSERT_EQ(executor.Init(), SUCCESS); | ||||
| auto &context = executor.context_; | auto &context = executor.context_; | ||||
| GraphItem graph_item; | |||||
| SubgraphExecutor subgraph_executor(&graph_item, &context); | |||||
| HybridModelExecutor::ExecuteArgs args; | HybridModelExecutor::ExecuteArgs args; | ||||
| std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | ||||
| eof_entry.first = nullptr; | eof_entry.first = nullptr; | ||||
| context.callback_manager->callback_queue_.Push(eof_entry); | context.callback_manager->callback_queue_.Push(eof_entry); | ||||
| ASSERT_EQ(executor.ExecuteGraphInternal(subgraph_executor, args), SUCCESS); | |||||
| ASSERT_EQ(executor.ExecuteGraphInternal(args), SUCCESS); | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -330,6 +330,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ||||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | ||||
| HybridModel model(root_model); | HybridModel model(root_model); | ||||
| model.root_graph_item_.reset(new GraphItem); | |||||
| HybridModel *model_ptr = &model; | HybridModel *model_ptr = &model; | ||||
| uint32_t device_id = 0; | uint32_t device_id = 0; | ||||