@@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { | |||||
Status HybridModelExecutor::Init() { | Status HybridModelExecutor::Init() { | ||||
GELOGD("Start to init HybridGraphEngine."); | GELOGD("Start to init HybridGraphEngine."); | ||||
GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | ||||
root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
GE_CHECK_NOTNULL(root_graph_executor_); | |||||
GELOGD("HybridGraphEngine initialized successfully."); | GELOGD("HybridGraphEngine initialized successfully."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | ||||
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | ||||
} | } | ||||
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | |||||
auto ret = ExecuteGraphInternal(executor, args); | |||||
auto ret = ExecuteGraphInternal(args); | |||||
Cleanup(); | Cleanup(); | ||||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | ||||
GELOGD("Model executed successfully."); | GELOGD("Model executed successfully."); | ||||
@@ -79,8 +80,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
HybridModelExecutor::ExecuteArgs &args) { | |||||
Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { | |||||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | ||||
GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | ||||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | ||||
@@ -94,7 +94,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | ||||
} | } | ||||
HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||||
"Failed to execute partitioned call."); | "Failed to execute partitioned call."); | ||||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | ||||
@@ -103,7 +103,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
} | } | ||||
if (!model_->IsSingleOp()) { | if (!model_->IsSingleOp()) { | ||||
Status ret = executor.Synchronize(); | |||||
Status ret = root_graph_executor_->Synchronize(); | |||||
if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
@@ -123,7 +123,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
} | } | ||||
args.outputs.clear(); | args.outputs.clear(); | ||||
HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -48,7 +48,7 @@ class HybridModelExecutor { | |||||
Status Execute(ExecuteArgs &args); | Status Execute(ExecuteArgs &args); | ||||
private: | private: | ||||
Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); | |||||
Status ExecuteGraphInternal(ExecuteArgs &args); | |||||
Status Cleanup(); | Status Cleanup(); | ||||
Status InitExecutionContext(); | Status InitExecutionContext(); | ||||
static Status ResetExecutionContext(GraphExecutionContext &context); | static Status ResetExecutionContext(GraphExecutionContext &context); | ||||
@@ -58,6 +58,7 @@ class HybridModelExecutor { | |||||
uint32_t device_id_; | uint32_t device_id_; | ||||
rtStream_t stream_; | rtStream_t stream_; | ||||
GraphExecutionContext context_; | GraphExecutionContext context_; | ||||
std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||||
}; | }; | ||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge | ||||
@@ -44,20 +44,56 @@ using std::vector; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const size_t kDataOutputNum = 1; | const size_t kDataOutputNum = 1; | ||||
const uint32_t kInputIndexOfData = 0; | |||||
const uint32_t kOutputIndexOfData = 0; | const uint32_t kOutputIndexOfData = 0; | ||||
constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | ||||
Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &flag) { | |||||
for (const auto &input_name : dependencies) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
int input_index = op_desc->GetInputIndexByName(input_name); | |||||
if (input_index < 0) { | |||||
GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", | |||||
node->GetName().c_str(), input_name.c_str()); | |||||
REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.", | |||||
node->GetName().c_str(), input_name.c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
const auto &in_anchor = node->GetInDataAnchor(input_index); | |||||
GE_CHECK_NOTNULL(in_anchor); | |||||
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(peer_out_anchor); | |||||
const auto &src_node = peer_out_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
auto src_op_desc = src_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(src_op_desc); | |||||
if (src_op_desc->GetType() == DATA) { | |||||
auto tensor = src_op_desc->MutableInputDesc(kInputIndexOfData); | |||||
if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) { | |||||
GELOGD("Get hostmem from node %s, inputname: %s.", src_node->GetName().c_str(), input_name.c_str()); | |||||
continue; | |||||
} | |||||
} | |||||
flag = false; | |||||
return SUCCESS; | |||||
} | |||||
flag = true; | |||||
return SUCCESS; | |||||
} | |||||
Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | ||||
auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | ||||
GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
for (const auto &node : comp_graph->GetAllNodes()) { | for (const auto &node : comp_graph->GetAllNodes()) { | ||||
GE_CHECK_NOTNULL(node); | |||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
const auto &depends = op_desc->GetOpInferDepends(); | const auto &depends = op_desc->GetOpInferDepends(); | ||||
bool support_dynamic_shape = false; | bool support_dynamic_shape = false; | ||||
(void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | ||||
if (!depends.empty() && support_dynamic_shape) { | if (!depends.empty() && support_dynamic_shape) { | ||||
flag = true; | |||||
CheckHostMem(depends, node, flag); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} | } | ||||
@@ -92,16 +92,15 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { | |||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | ||||
ge_root_model->SetModelName("test_name"); | ge_root_model->SetModelName("test_name"); | ||||
HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
hybrid_model.root_graph_item_.reset(new GraphItem); | |||||
HybridModelExecutor executor(&hybrid_model, 0, nullptr); | HybridModelExecutor executor(&hybrid_model, 0, nullptr); | ||||
ASSERT_EQ(executor.Init(), SUCCESS); | ASSERT_EQ(executor.Init(), SUCCESS); | ||||
auto &context = executor.context_; | auto &context = executor.context_; | ||||
GraphItem graph_item; | |||||
SubgraphExecutor subgraph_executor(&graph_item, &context); | |||||
HybridModelExecutor::ExecuteArgs args; | HybridModelExecutor::ExecuteArgs args; | ||||
std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | ||||
eof_entry.first = nullptr; | eof_entry.first = nullptr; | ||||
context.callback_manager->callback_queue_.Push(eof_entry); | context.callback_manager->callback_queue_.Push(eof_entry); | ||||
ASSERT_EQ(executor.ExecuteGraphInternal(subgraph_executor, args), SUCCESS); | |||||
ASSERT_EQ(executor.ExecuteGraphInternal(args), SUCCESS); | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -330,6 +330,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | ||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | ||||
HybridModel model(root_model); | HybridModel model(root_model); | ||||
model.root_graph_item_.reset(new GraphItem); | |||||
HybridModel *model_ptr = &model; | HybridModel *model_ptr = &model; | ||||
uint32_t device_id = 0; | uint32_t device_id = 0; | ||||