diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index f47a02fd..b39975bc 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -257,7 +257,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } // cond or branch need to be prepared before the execution of IF or CASE - if (node_item.node_type == IF || node_item.node_type == CASE) { + if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { const auto &in_anchor = ge_node->GetInDataAnchor(0); GE_CHECK_NOTNULL(in_anchor); const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); @@ -917,7 +917,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr auto parent_node = sub_graph.GetParentNode(); GE_CHECK_NOTNULL(parent_node); auto op_type = parent_node->GetType(); - if (op_type == IF || op_type == CASE || op_type == WHILE) { + if (IsControlOp(op_type)) { GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), ge_model->GetModelTaskDefPtr()->task_size()); diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 4a019487..a105aaaf 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -28,6 +28,9 @@ namespace hybrid { namespace { const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char * const kNodeTypeRetVal = "_RetVal"; +std::set kControlOpTypes { + IF, STATELESSIF, CASE, WHILE, STATELESSWHILE +} Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { uint32_t parent_index = 0; @@ -102,6 +105,11 @@ Status ParseFusedSubgraph(NodeItem &node_item) { return SUCCESS; } } // namespace + +bool IsControlOp(const std::string &op_type) { + return kControlOpTypes.count(op_type) > 0; +} + NodeItem::NodeItem(NodePtr node): node(std::move(node)) { this->op_desc = this->node->GetOpDesc().get(); this->node_id = this->op_desc->GetId(); @@ -153,8 +161,7 @@ Status NodeItem::Init() { } bool NodeItem::IsControlOp() const { - auto op_type = op_desc->GetType(); - return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR; + return ge::hybrid::IsControlOp(op_desc->GetType()); } std::string NodeItem::DebugString() const { diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index c10cf13e..a59c4dc1 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -36,6 +36,8 @@ struct FusedSubgraph { ComputeGraphPtr graph; }; +bool IsControlOp(const std::string &op_type); + // for caching static information across execution struct NodeItem { explicit NodeItem(NodePtr node); diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.cc b/ge/hybrid/node_executor/controlop/control_op_executor.cc index 5f9dde2a..88c82e6e 100644 --- a/ge/hybrid/node_executor/controlop/control_op_executor.cc +++ b/ge/hybrid/node_executor/controlop/control_op_executor.cc @@ -404,11 +404,11 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, unique_ptr node_task; auto node_type = node->GetType(); - if (node_type == IF) { + if (node_type == IF || node_type == STATELESSIF) { node_task.reset(new(std::nothrow) IfOpNodeTask()); } else if (node_type == CASE) { node_task.reset(new(std::nothrow) CaseOpNodeTask()); - } else if (node_type == WHILE) { + } else if (node_type == WHILE || node_type == STATELESSWHILE) { node_task.reset(new(std::nothrow) WhileOpNodeTask()); } else { GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index fdfdfb51..7f11c20a 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node return ExecutorType::GE_LOCAL; } - if (op_type == IF || op_type == CASE || op_type == WHILE) { + if (IsControlOp(op_type)) { return ExecutorType::CONTROL_OP; }