diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc index 6e0edfc5..f6c0cf79 100755 --- a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc +++ b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc @@ -39,6 +39,10 @@ const std::map> const std::set DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; +const std::set ConstantNodeTask::constant_like_task_ops_ = {CONSTANT, CONSTANTOP, VARIABLE}; + +const std::set NoOpNodeTask::control_only_task_ops_ = {NOOP, CONTROLTRIGGER}; + Status RefInputTask::UpdateArgs(TaskContext &) { // no need update args return SUCCESS; @@ -244,7 +248,7 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, node->GetName().c_str(), node_type.c_str()); return MEMALLOC_FAILED; } - } else if (node_type == CONSTANT || node_type == CONSTANTOP || node_type == VARIABLE) { + } else if (ConstantNodeTask::IsBelong(node_type)) { GELOGI("node %s type %s, use ConstantNodeTask.", node->GetName().c_str(), node_type.c_str()); auto tensor = model.GetTensor(node); if (tensor == nullptr) { @@ -254,7 +258,7 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, } task = MakeShared(tensor); GE_CHECK_NOTNULL(task); - } else if (node_type == NOOP) { + } else if (NoOpNodeTask::IsBelong(node_type)) { GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str()); task = MakeShared(); if (task == nullptr) { @@ -288,6 +292,10 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function 0; +} + Status NoOpNodeTask::UpdateArgs(TaskContext &context) { // no need to update args return SUCCESS; @@ -299,5 +307,9 @@ Status NoOpNodeTask::ExecuteAsync(TaskContext &context, std::function do GELOGD("[%s] Done execute successfully.", context.GetNodeName()); return SUCCESS; } + +bool NoOpNodeTask::IsBelong(const std::string &op_type) { + return control_only_task_ops_.count(op_type) > 0; +} } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h index 948809f8..4b734ef3 100644 --- a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h +++ b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h @@ -75,8 +75,10 @@ class ConstantNodeTask : public NodeTask { Status UpdateArgs(TaskContext &context) override; Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + static bool IsBelong(const std::string &op_type); private: + static const std::set constant_like_task_ops_; const TensorValue *tensor_; }; @@ -86,6 +88,10 @@ class NoOpNodeTask : public NodeTask { ~NoOpNodeTask() = default; Status UpdateArgs(TaskContext &context) override; Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + static bool IsBelong(const std::string &op_type); + + private: + static const std::set control_only_task_ops_; }; class GeLocalNodeExecutor : public NodeExecutor {