Browse Source

update

tags/v1.2.0
chuxing 3 years ago
parent
commit
68bbf9e41c
4 changed files with 6 additions and 30 deletions
  1. +1
    -0
      ge/hybrid/executor/hybrid_model_executor.cc
  2. +1
    -0
      ge/hybrid/model/node_item.h
  3. +2
    -26
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
  4. +2
    -4
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h

+ 1
- 0
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -95,6 +95,7 @@ Status HybridModelExecutor::InitExecutionContext() {


context_.stream = stream_; context_.stream = stream_;
context_.model = model_; context_.model = model_;
context_.is_eos_ = false;
context_.session_id = ::ge::GetContext().SessionId(); context_.session_id = ::ge::GetContext().SessionId();
context_.ge_context = &GetThreadLocalContext(); context_.ge_context = &GetThreadLocalContext();
GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id);


+ 1
- 0
ge/hybrid/model/node_item.h View File

@@ -82,6 +82,7 @@ struct NodeItem {
bool has_observer = false; bool has_observer = false;
bool has_optional_inputs = false; bool has_optional_inputs = false;
bool is_output_shape_static = true; bool is_output_shape_static = true;
bool may_trigger_eos_ = false;
UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE;
std::string node_name; std::string node_name;
std::string node_type; std::string node_type;


+ 2
- 26
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc View File

@@ -21,8 +21,6 @@
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "graph/attr_value.h" #include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/load/new_model_manager/model_utils.h" #include "graph/load/new_model_manager/model_utils.h"
#include "graph/load/new_model_manager/model_manager.h" #include "graph/load/new_model_manager/model_manager.h"
#include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/hybrid_execution_context.h"
@@ -31,7 +29,7 @@ namespace ge {
namespace hybrid { namespace hybrid {
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor); REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor);


Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
Status KnownNodeTask:: ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start");
GELOGD("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName()); GELOGD("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName());
if (davinci_model_->GetTaskList().empty()) { if (davinci_model_->GetTaskList().empty()) {
@@ -60,10 +58,6 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> d
GELOGE(rt_ret, "rtModelExecute error, ret: hybrid_model_executorOx%X", rt_ret); return FAILED;); GELOGE(rt_ret, "rtModelExecute error, ret: hybrid_model_executorOx%X", rt_ret); return FAILED;);
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End");


if (need_sync_) {
GELOGD("[%s] model need sync", context.GetNodeName());
GE_CHK_STATUS_RET_NOLOG(context.Synchronize());
}
GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback));
GELOGD("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName()); GELOGD("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName());
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End");
@@ -177,9 +171,7 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node


GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed.");


bool need_sync = false;
GE_CHK_STATUS_RET_NOLOG(NeedSync(*ge_model, need_sync));
task = MakeShared<KnownNodeTask>(davinci_model, need_sync);
task = MakeShared<KnownNodeTask>(davinci_model);
GE_CHECK_NOTNULL(task); GE_CHECK_NOTNULL(task);
GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str()); GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
return SUCCESS; return SUCCESS;
@@ -194,21 +186,5 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context,
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End");
return SUCCESS; return SUCCESS;
} }

Status KnownNodeExecutor::NeedSync(GeModel &ge_model, bool &need_sync) {
auto compute_graph = GraphUtils::GetComputeGraph(ge_model.GetGraph());
GE_CHECK_NOTNULL(compute_graph);
for (auto &node : compute_graph->GetAllNodes()) {
auto type = NodeUtils::GetNodeType(node);
if (type == GETNEXT) {
GELOGD("Contains GetNext node: %s", node->GetName().c_str());
need_sync = true;
return SUCCESS;
}
}

need_sync = false;
return SUCCESS;
}
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

+ 2
- 4
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h View File

@@ -27,8 +27,8 @@ class HybridModel;


class KnownNodeTask : public NodeTask { class KnownNodeTask : public NodeTask {
public: public:
explicit KnownNodeTask(std::shared_ptr<DavinciModel> davinci_model, bool need_sync)
: davinci_model_(davinci_model), need_sync_(need_sync)
explicit KnownNodeTask(std::shared_ptr<DavinciModel> davinci_model)
: davinci_model_(davinci_model)
{} {}


~KnownNodeTask() {} ~KnownNodeTask() {}
@@ -39,7 +39,6 @@ class KnownNodeTask : public NodeTask {
private: private:
std::shared_ptr<DavinciModel> davinci_model_ = nullptr; std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
bool load_flag_ = false; bool load_flag_ = false;
bool need_sync_;
}; };


class KnownNodeExecutor : public NodeExecutor { class KnownNodeExecutor : public NodeExecutor {
@@ -49,7 +48,6 @@ class KnownNodeExecutor : public NodeExecutor {
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
~KnownNodeExecutor() {} ~KnownNodeExecutor() {}
private: private:
static Status NeedSync(GeModel &ge_model, bool &need_sync);
std::shared_ptr<DavinciModel> davinci_model_ = nullptr; std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
}; };
} // namespace hybrid } // namespace hybrid


Loading…
Cancel
Save