From: @xchu42 Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @ji_chentags/v1.1.0
@@ -257,7 +257,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
} | } | ||||
// cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
if (node_item.node_type == IF || node_item.node_type == CASE) { | |||||
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | |||||
const auto &in_anchor = ge_node->GetInDataAnchor(0); | const auto &in_anchor = ge_node->GetInDataAnchor(0); | ||||
GE_CHECK_NOTNULL(in_anchor); | GE_CHECK_NOTNULL(in_anchor); | ||||
const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | ||||
@@ -920,7 +920,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr | |||||
auto parent_node = sub_graph.GetParentNode(); | auto parent_node = sub_graph.GetParentNode(); | ||||
GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
auto op_type = parent_node->GetType(); | auto op_type = parent_node->GetType(); | ||||
if (op_type == IF || op_type == CASE || op_type == WHILE) { | |||||
if (IsControlOp(op_type)) { | |||||
GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", | GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", | ||||
sub_graph.GetName().c_str(), | sub_graph.GetName().c_str(), | ||||
ge_model->GetModelTaskDefPtr()->task_size()); | ge_model->GetModelTaskDefPtr()->task_size()); | ||||
@@ -28,6 +28,9 @@ namespace hybrid { | |||||
namespace { | namespace { | ||||
const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
const char * const kNodeTypeRetVal = "_RetVal"; | const char * const kNodeTypeRetVal = "_RetVal"; | ||||
std::set<std::string> kControlOpTypes { | |||||
IF, STATELESSIF, CASE, WHILE, STATELESSWHILE | |||||
}; | |||||
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
@@ -102,6 +105,11 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
bool IsControlOp(const std::string &op_type) { | |||||
return kControlOpTypes.count(op_type) > 0; | |||||
} | |||||
NodeItem::NodeItem(NodePtr node): node(std::move(node)) { | NodeItem::NodeItem(NodePtr node): node(std::move(node)) { | ||||
this->op_desc = this->node->GetOpDesc().get(); | this->op_desc = this->node->GetOpDesc().get(); | ||||
this->node_id = this->op_desc->GetId(); | this->node_id = this->op_desc->GetId(); | ||||
@@ -153,8 +161,7 @@ Status NodeItem::Init() { | |||||
} | } | ||||
bool NodeItem::IsControlOp() const { | bool NodeItem::IsControlOp() const { | ||||
auto op_type = op_desc->GetType(); | |||||
return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR; | |||||
return ge::hybrid::IsControlOp(op_desc->GetType()); | |||||
} | } | ||||
std::string NodeItem::DebugString() const { | std::string NodeItem::DebugString() const { | ||||
@@ -36,6 +36,8 @@ struct FusedSubgraph { | |||||
ComputeGraphPtr graph; | ComputeGraphPtr graph; | ||||
}; | }; | ||||
bool IsControlOp(const std::string &op_type); | |||||
// for caching static information across execution | // for caching static information across execution | ||||
struct NodeItem { | struct NodeItem { | ||||
explicit NodeItem(NodePtr node); | explicit NodeItem(NodePtr node); | ||||
@@ -404,11 +404,11 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, | |||||
unique_ptr<ControlOpNodeTask> node_task; | unique_ptr<ControlOpNodeTask> node_task; | ||||
auto node_type = node->GetType(); | auto node_type = node->GetType(); | ||||
if (node_type == IF) { | |||||
if (node_type == IF || node_type == STATELESSIF) { | |||||
node_task.reset(new(std::nothrow) IfOpNodeTask()); | node_task.reset(new(std::nothrow) IfOpNodeTask()); | ||||
} else if (node_type == CASE) { | } else if (node_type == CASE) { | ||||
node_task.reset(new(std::nothrow) CaseOpNodeTask()); | node_task.reset(new(std::nothrow) CaseOpNodeTask()); | ||||
} else if (node_type == WHILE) { | |||||
} else if (node_type == WHILE || node_type == STATELESSWHILE) { | |||||
node_task.reset(new(std::nothrow) WhileOpNodeTask()); | node_task.reset(new(std::nothrow) WhileOpNodeTask()); | ||||
} else { | } else { | ||||
GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); | GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); | ||||
@@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node | |||||
return ExecutorType::GE_LOCAL; | return ExecutorType::GE_LOCAL; | ||||
} | } | ||||
if (op_type == IF || op_type == CASE || op_type == WHILE) { | |||||
if (IsControlOp(op_type)) { | |||||
return ExecutorType::CONTROL_OP; | return ExecutorType::CONTROL_OP; | ||||
} | } | ||||