@@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() { | |||
/// Initializes the hybrid model executor.
///
/// Sets up the execution context first, then allocates the SubgraphExecutor for the
/// model's root graph item and stores it in root_graph_executor_.
///
/// @return SUCCESS when both steps succeed; otherwise the status propagated by
///         GE_CHK_STATUS_RET_NOLOG / GE_CHECK_NOTNULL (exact codes are defined by
///         those macros, which are not visible here).
Status HybridModelExecutor::Init() {
  GELOGD("Start to init HybridGraphEngine.");
  GE_CHK_STATUS_RET_NOLOG(InitExecutionContext());

  // Non-throwing allocation: a failed allocation yields nullptr, which the
  // null check below turns into an error status instead of an exception.
  const auto root_graph_item = model_->GetRootGraphItem();
  root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(root_graph_item, &context_));
  GE_CHECK_NOTNULL(root_graph_executor_);

  GELOGD("HybridGraphEngine initialized successfully.");
  return SUCCESS;
}
@@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, | |||
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); | |||
} | |||
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); | |||
auto ret = ExecuteGraphInternal(executor, args); | |||
auto ret = ExecuteGraphInternal(args); | |||
Cleanup(); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | |||
GELOGD("Model executed successfully."); | |||
@@ -79,8 +80,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||
return SUCCESS; | |||
} | |||
Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
HybridModelExecutor::ExecuteArgs &args) { | |||
Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); | |||
GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); | |||
@@ -94,7 +94,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id)); | |||
} | |||
HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||
HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs), | |||
"Failed to execute partitioned call."); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); | |||
@@ -103,7 +103,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
} | |||
if (!model_->IsSingleOp()) { | |||
Status ret = executor.Synchronize(); | |||
Status ret = root_graph_executor_->Synchronize(); | |||
if (ret != ge::SUCCESS) { | |||
auto model_manager = ModelManager::GetInstance(); | |||
GE_CHECK_NOTNULL(model_manager); | |||
@@ -123,7 +123,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||
} | |||
args.outputs.clear(); | |||
HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||
HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); | |||
RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); | |||
return SUCCESS; | |||
} | |||
@@ -48,7 +48,7 @@ class HybridModelExecutor { | |||
Status Execute(ExecuteArgs &args); | |||
private: | |||
Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); | |||
Status ExecuteGraphInternal(ExecuteArgs &args); | |||
Status Cleanup(); | |||
Status InitExecutionContext(); | |||
static Status ResetExecutionContext(GraphExecutionContext &context); | |||
@@ -58,6 +58,7 @@ class HybridModelExecutor { | |||
uint32_t device_id_; | |||
rtStream_t stream_; | |||
GraphExecutionContext context_; | |||
std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
@@ -44,20 +44,56 @@ using std::vector; | |||
namespace ge { | |||
namespace { | |||
// NOTE(review): presumably the number of outputs a DATA node is expected to have;
// not referenced in the visible part of this file -- confirm against its users.
const size_t kDataOutputNum = 1; | |||
// Index of the input tensor desc that is queried on a DATA op desc.
const uint32_t kInputIndexOfData = 0; | |||
// NOTE(review): presumably the matching output tensor desc index of a DATA op;
// not referenced in the visible part of this file -- confirm against its users.
const uint32_t kOutputIndexOfData = 0; | |||
// Name of the boolean op attribute read to decide whether an op supports dynamic shape.
constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | |||
/// Checks whether every listed dependency input of `node` is produced by a DATA
/// node whose input tensor desc carries ATTR_NAME_VALUE.
///
/// Each name in `dependencies` is resolved to an input index on `node`, the peer
/// output anchor of that input is followed to its owner node, and that source
/// node is inspected.
///
/// @param dependencies  Names of the inputs of `node` to inspect.
/// @param node          Node whose inputs are inspected; must not be null.
/// @param flag          Output. Set to true when all dependencies qualify (also
///                      when `dependencies` is empty), and to false as soon as
///                      one does not. Left untouched when an error is returned.
/// @return SUCCESS when `flag` was determined; INTERNAL_ERROR when an input name
///         cannot be resolved to an index; the status returned by
///         GE_CHECK_NOTNULL when a required pointer is null.
Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &flag) {
  GE_CHECK_NOTNULL(node);
  // The op desc is the same for every dependency: fetch and validate it once
  // instead of dereferencing an unchecked pointer on each iteration.
  const auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);

  for (const auto &input_name : dependencies) {
    const int input_index = op_desc->GetInputIndexByName(input_name);
    if (input_index < 0) {
      GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.",
             node->GetName().c_str(), input_name.c_str());
      REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.",
                        node->GetName().c_str(), input_name.c_str());
      return INTERNAL_ERROR;
    }

    // Walk from the input anchor to the node that produces this input.
    const auto &in_anchor = node->GetInDataAnchor(input_index);
    GE_CHECK_NOTNULL(in_anchor);
    const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(peer_out_anchor);
    const auto &src_node = peer_out_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    auto src_op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(src_op_desc);

    if (src_op_desc->GetType() == DATA) {
      auto tensor = src_op_desc->MutableInputDesc(kInputIndexOfData);
      if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) {
        GELOGD("Get hostmem from node %s, inputname: %s.", src_node->GetName().c_str(), input_name.c_str());
        continue;
      }
    }

    // One non-qualifying dependency is enough to decide the result.
    flag = false;
    return SUCCESS;
  }

  flag = true;
  return SUCCESS;
}
Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||
auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||
GE_CHECK_NOTNULL(comp_graph); | |||
for (const auto &node : comp_graph->GetAllNodes()) { | |||
GE_CHECK_NOTNULL(node); | |||
auto op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
const auto &depends = op_desc->GetOpInferDepends(); | |||
bool support_dynamic_shape = false; | |||
(void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | |||
if (!depends.empty() && support_dynamic_shape) { | |||
flag = true; | |||
CheckHostMem(depends, node, flag); | |||
return SUCCESS; | |||
} | |||
} | |||
@@ -92,16 +92,15 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute_internal) { | |||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||
ge_root_model->SetModelName("test_name"); | |||
HybridModel hybrid_model(ge_root_model); | |||
hybrid_model.root_graph_item_.reset(new GraphItem); | |||
HybridModelExecutor executor(&hybrid_model, 0, nullptr); | |||
ASSERT_EQ(executor.Init(), SUCCESS); | |||
auto &context = executor.context_; | |||
GraphItem graph_item; | |||
SubgraphExecutor subgraph_executor(&graph_item, &context); | |||
HybridModelExecutor::ExecuteArgs args; | |||
std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> eof_entry; | |||
eof_entry.first = nullptr; | |||
context.callback_manager->callback_queue_.Push(eof_entry); | |||
ASSERT_EQ(executor.ExecuteGraphInternal(subgraph_executor, args), SUCCESS); | |||
ASSERT_EQ(executor.ExecuteGraphInternal(args), SUCCESS); | |||
} | |||
} // namespace ge |
@@ -330,6 +330,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | |||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||
HybridModel model(root_model); | |||
model.root_graph_item_.reset(new GraphItem); | |||
HybridModel *model_ptr = &model; | |||
uint32_t device_id = 0; | |||