|
|
@@ -21,8 +21,6 @@ |
|
|
|
#include "common/ge/ge_util.h" |
|
|
|
#include "graph/attr_value.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
#include "graph/load/new_model_manager/model_utils.h" |
|
|
|
#include "graph/load/new_model_manager/model_manager.h" |
|
|
|
#include "hybrid/executor/hybrid_execution_context.h" |
|
|
@@ -31,7 +29,7 @@ namespace ge { |
|
|
|
namespace hybrid { |
|
|
|
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor); |
|
|
|
|
|
|
|
Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { |
|
|
|
Status KnownNodeTask:: ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { |
|
|
|
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start"); |
|
|
|
GELOGD("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName()); |
|
|
|
if (davinci_model_->GetTaskList().empty()) { |
|
|
@@ -60,10 +58,6 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> d |
|
|
|
GELOGE(rt_ret, "rtModelExecute error, ret: hybrid_model_executorOx%X", rt_ret); return FAILED;); |
|
|
|
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End"); |
|
|
|
|
|
|
|
if (need_sync_) { |
|
|
|
GELOGD("[%s] model need sync", context.GetNodeName()); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(context.Synchronize()); |
|
|
|
} |
|
|
|
GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); |
|
|
|
GELOGD("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName()); |
|
|
|
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End"); |
|
|
@@ -177,9 +171,7 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); |
|
|
|
|
|
|
|
bool need_sync = false; |
|
|
|
GE_CHK_STATUS_RET_NOLOG(NeedSync(*ge_model, need_sync)); |
|
|
|
task = MakeShared<KnownNodeTask>(davinci_model, need_sync); |
|
|
|
task = MakeShared<KnownNodeTask>(davinci_model); |
|
|
|
GE_CHECK_NOTNULL(task); |
|
|
|
GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
@@ -194,21 +186,5 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, |
|
|
|
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status KnownNodeExecutor::NeedSync(GeModel &ge_model, bool &need_sync) { |
|
|
|
auto compute_graph = GraphUtils::GetComputeGraph(ge_model.GetGraph()); |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
for (auto &node : compute_graph->GetAllNodes()) { |
|
|
|
auto type = NodeUtils::GetNodeType(node); |
|
|
|
if (type == GETNEXT) { |
|
|
|
GELOGD("Contains GetNext node: %s", node->GetName().c_str()); |
|
|
|
need_sync = true; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
need_sync = false; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace hybrid |
|
|
|
} // namespace ge |