diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index d8939175..9bf70d26 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { Status HybridModelExecutor::Init() { GELOGD("Start to init HybridGraphEngine."); GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); + root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); + GE_CHECK_NOTNULL(root_graph_executor_); GELOGD("HybridGraphEngine initialized successfully."); return SUCCESS; } @@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); } - SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); - auto ret = ExecuteGraphInternal(executor, args); + auto ret = ExecuteGraphInternal(args); Cleanup(); RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); GELOGD("Model executed successfully."); @@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { context_.profiler->Dump(std::cout); context_.profiler->Reset(); } + root_graph_executor_->ReleaseContext(); context_.iteration += 1; if (ret == END_OF_SEQUENCE) { @@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { return SUCCESS; } -Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, - HybridModelExecutor::ExecuteArgs &args) { +Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); @@ -94,7 +95,7 @@ Status 
HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); } - HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), + HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), "Failed to execute partitioned call."); RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); @@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, } if (!model_->IsSingleOp()) { - Status ret = executor.Synchronize(); + Status ret = root_graph_executor_->Synchronize(); if (ret != ge::SUCCESS) { auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); @@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, } args.outputs.clear(); - HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); + HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); return SUCCESS; } diff --git a/ge/hybrid/executor/hybrid_model_executor.h b/ge/hybrid/executor/hybrid_model_executor.h index 566043d9..102e4f8b 100644 --- a/ge/hybrid/executor/hybrid_model_executor.h +++ b/ge/hybrid/executor/hybrid_model_executor.h @@ -48,7 +48,7 @@ class HybridModelExecutor { Status Execute(ExecuteArgs &args); private: - Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); + Status ExecuteGraphInternal(ExecuteArgs &args); Status Cleanup(); Status InitExecutionContext(); static Status ResetExecutionContext(GraphExecutionContext &context); @@ -58,6 +58,7 @@ class HybridModelExecutor { uint32_t device_id_; rtStream_t stream_; GraphExecutionContext context_; + std::unique_ptr<SubgraphExecutor> root_graph_executor_; }; } // namespace hybrid } // namespace ge diff --git
a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 85f9e4c3..b80b60b0 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -177,6 +177,10 @@ struct NodeState { void SetTaskContext(std::shared_ptr<TaskContext> &task_context); std::shared_ptr<TaskContext> GetTaskContext(); + void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; } + + bool MaySkipShapeInference() const { return skip_infershape_; } + private: bool IsScheduleReady() const; void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); @@ -204,6 +208,7 @@ struct NodeState { int merge_index_ = -1; // Use for Execute (Reset after Executed). int switch_index_ = -1; // Use for Schedule (Reset after Prepared). int group_ = -1; + bool skip_infershape_ = false; }; } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 7429acc5..6979d05f 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -103,6 +103,14 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue> &inputs, auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); GE_CHECK_NOTNULL(node_state); node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex); + GE_CHECK_NOTNULL(output_desc); + output_desc->SetShape(tensor_desc->GetShape()); + output_desc->SetOriginShape(tensor_desc->GetOriginShape()); + output_desc->SetDataType(tensor_desc->GetDataType()); + node_state->SetSkipInferShape(true); } } diff --git a/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h index e4c0debe..76732c37 100644 --- a/ge/hybrid/executor/subgraph_executor.h +++ b/ge/hybrid/executor/subgraph_executor.h @@ -41,6 +41,8 @@ class SubgraphExecutor { Status PartialExecuteAsync(int task_group); + void ReleaseContext() {
subgraph_context_.reset(nullptr); } + /** * Execute subgraph async, output tensor address(not data) and output tensor descriptions are * valid after this method returned diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index a2efbb25..50dc389c 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -68,8 +68,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { } // Do shape inference + // Skipping infer shape of input node. GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); - { + if (!node_state.MaySkipShapeInference()) { RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 67642f2e..90a6362c 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -44,29 +44,63 @@ using std::vector; namespace ge { namespace { const size_t kDataOutputNum = 1; +const uint32_t kInputIndexOfData = 0; const uint32_t kOutputIndexOfData = 0; constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; -Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { +Status CheckHostMem(const std::vector<std::string> &dependencies, const NodePtr &node, bool &is_host_mem) { + auto op_desc = node->GetOpDesc(); + for (const auto &input_name : dependencies) { + int input_index = op_desc->GetInputIndexByName(input_name); + if (input_index < 0) { + GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", + node->GetName().c_str(), input_name.c_str()); + REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", + node->GetName().c_str(),
input_name.c_str()); + return INTERNAL_ERROR; + } + + const auto &src_node = NodeUtils::GetInDataNodeByIndex(*node, input_index); + GE_CHECK_NOTNULL(src_node); + auto src_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); + if (src_op_desc->GetType() == DATA) { + auto tensor = src_op_desc->MutableInputDesc(kInputIndexOfData); + if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { + GELOGD("Get hostmem from node %s, inputname: %s.", src_node->GetName().c_str(), input_name.c_str()); + continue; + } + } + is_host_mem = false; + return SUCCESS; + } + is_host_mem = true; + return SUCCESS; +} + +Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) { auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); GE_CHECK_NOTNULL(comp_graph); for (const auto &node : comp_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); const auto &depends = op_desc->GetOpInferDepends(); bool support_dynamic_shape = false; (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); if (!depends.empty() && support_dynamic_shape) { - flag = true; - return SUCCESS; + is_infer_depend = true; + return CheckHostMem(depends, node, is_host_mem); } } return SUCCESS; } Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { - bool infer_depend_flag = false; - GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); + bool is_infer_depend = false; + bool is_host_mem = false; + GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); + bool need_d2h_cpy = is_infer_depend && !is_host_mem; auto tasks = ge_model->GetModelTaskDefPtr()->task(); int32_t kernel_task_num = 0; for (int i = 0; i < tasks.size(); ++i) { @@ -76,7 +110,7 @@ Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { tasks[i].kernel_with_handle().context(); auto kernel_type = 
static_cast<ccKernelType>(context.kernel_type()); if (kernel_type == ccKernelType::TE) { - if (infer_depend_flag) { + if (need_d2h_cpy) { flag = true; return SUCCESS; } @@ -517,7 +551,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); bool infer_depend_flag = false; - GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); + bool is_host_mem = false; + GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem), "[Check][InferDepend] failed."); if (infer_depend_flag) { // construct single_op, do single op with HybridModelExecutor GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor."); diff --git a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc index d2679439..98bb78f2 100644 --- a/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/hybrid_model_async_executor_unittest.cc @@ -87,21 +87,20 @@ TEST_F(UtestHybridModelAsyncExecutor, BuildDeviceTensor) { ASSERT_EQ(size, 100); } -TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { +TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); ge_root_model->SetModelName("test_name"); HybridModel hybrid_model(ge_root_model); + hybrid_model.root_graph_item_.reset(new GraphItem); HybridModelExecutor executor(&hybrid_model, 0, nullptr); ASSERT_EQ(executor.Init(), SUCCESS); auto &context = executor.context_; - GraphItem graph_item; - SubgraphExecutor subgraph_executor(&graph_item, &context); HybridModelExecutor::ExecuteArgs args; std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; eof_entry.first = nullptr; context.callback_manager->callback_queue_.Push(eof_entry); - ASSERT_EQ(executor.ExecuteGraphInternal(subgraph_executor,
args), SUCCESS); + ASSERT_EQ(executor.Execute(args), SUCCESS); } } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 228af832..4f14f628 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -329,6 +329,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); GeRootModelPtr root_model = MakeShared<GeRootModel>(compute_graph); HybridModel model(root_model); + model.root_graph_item_.reset(new GraphItem); HybridModel *model_ptr = &model; uint32_t device_id = 0; diff --git a/tests/ut/ge/single_op/single_op_model_unittest.cc b/tests/ut/ge/single_op/single_op_model_unittest.cc index a2c1cb02..1975f9f4 100644 --- a/tests/ut/ge/single_op/single_op_model_unittest.cc +++ b/tests/ut/ge/single_op/single_op_model_unittest.cc @@ -17,12 +17,11 @@ #include <gtest/gtest.h> #include <vector> +#define protected public +#define private public #include "graph/load/model_manager/model_utils.h" #include "graph/utils/graph_utils.h" #include "runtime/rt.h" - -#define protected public -#define private public #include "single_op/single_op_model.h" #include "single_op/task/tbe_task_builder.h" #include "single_op/task/rts_kernel_task_builder.h" @@ -30,14 +29,19 @@ #include "framework/common/helper/model_helper.h" #include "single_op/single_op.h" #include "single_op/stream_resource.h" +#include "graph/passes/graph_builder_utils.h" +#include "graph/op_desc_impl.h" #undef private #undef protected -#include "graph/passes/graph_builder_utils.h" using namespace std; using namespace testing; using namespace ge; +namespace { +constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; +} // namespace + class UtestSingleOpModel : public testing::Test { protected: void SetUp() {} @@ -208,12 +212,22 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { model.model_helper_.model_ = ge::MakeShared<GeModel>(); // make graph - auto compute_graph =
make_shared<ComputeGraph>("graph"); - auto data_op = make_shared<OpDesc>("Data", DATA); - auto data_node = compute_graph->AddNode(data_op); + ut::GraphBuilder builder = ut::GraphBuilder("graph"); + auto data = builder.AddNode("Data", "Data", 1, 1); + auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); + auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); + builder.AddDataEdge(data, 0, transdata, 0); + builder.AddDataEdge(transdata, 0, netoutput, 0); + auto compute_graph = builder.GetGraph(); + auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); model.model_helper_.model_->SetGraph(graph); + auto op_desc = transdata->GetOpDesc(); + const vector<string> depend_names = { "Data" }; + op_desc->SetOpInferDepends(depend_names); + (void)AttrUtils::SetBool(op_desc, kAttrSupportDynamicShape, true); + // set task_def auto model_task_def = make_shared<domi::ModelTaskDef>(); domi::TaskDef *task_def = model_task_def->add_task(); @@ -227,6 +241,15 @@ DynamicSingleOp dynamic_single_op(0, &stream_mu_, nullptr); StreamResource res((uintptr_t)1); model.BuildDynamicOp(res, dynamic_single_op); + + op_desc->impl_->input_name_idx_["Data"] = 0; + model.BuildDynamicOp(res, dynamic_single_op); + + auto tensor = std::make_shared<GeTensor>(); + auto data_desc = data->GetOpDesc(); + auto tensor_desc = data_desc->MutableInputDesc(0); + AttrUtils::SetTensor(tensor_desc, "_value", tensor); + model.BuildDynamicOp(res, dynamic_single_op); } TEST_F(UtestSingleOpModel, test_host_mem) {