| @@ -24,6 +24,8 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| const uint8_t kMaxTransCount = 3; | |||||
| const uint32_t kTransOpIoSize = 1; | |||||
| const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
| const char *const kNodeTypeRetVal = "_RetVal"; | const char *const kNodeTypeRetVal = "_RetVal"; | ||||
| const std::set<std::string> kControlOpTypes{ | const std::set<std::string> kControlOpTypes{ | ||||
| @@ -39,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{ | |||||
| MERGE, REFMERGE, STREAMMERGE | MERGE, REFMERGE, STREAMMERGE | ||||
| }; | }; | ||||
| bool IsEnterFeedNode(NodePtr node) { | |||||
| // For: Enter -> node | |||||
| // For: Enter -> Cast -> node | |||||
| // For: Enter -> TransData -> Cast -> node | |||||
| for (uint8_t i = 0; i < kMaxTransCount; ++i) { | |||||
| if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
| GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str()); | |||||
| return true; | |||||
| } | |||||
| const auto all_nodes = node->GetInDataNodes(); | |||||
| if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { | |||||
| return false; | |||||
| } | |||||
| node = all_nodes.at(0); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
| uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
| if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
| @@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
| data_anchors.emplace(anchor_index); | data_anchors.emplace(anchor_index); | ||||
| } | } | ||||
| // If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||||
| if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { | |||||
| auto &data_anchors = node_item->enter_data_[this]; | auto &data_anchors = node_item->enter_data_[this]; | ||||
| data_anchors.emplace(anchor_index); | data_anchors.emplace(anchor_index); | ||||
| } | } | ||||
| @@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
| node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
| } | } | ||||
| // If Enter feed control signal, take as root Node. | // If Enter feed control signal, take as root Node. | ||||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
| if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
| node_item->enter_ctrl_.emplace(this); | node_item->enter_ctrl_.emplace(this); | ||||
| } | } | ||||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||