Browse Source

!1918 IsEnterFeedNode

Merge pull request !1918 from 梁昊/lh2
tags/v1.5.1
i-robot Gitee 3 years ago
parent
commit
23ad613fdc
1 changed files with 23 additions and 2 deletions
  1. +23
    -2
      ge/hybrid/model/node_item.cc

+ 23
- 2
ge/hybrid/model/node_item.cc View File

@@ -24,6 +24,8 @@
namespace ge {
namespace hybrid {
namespace {
const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{
@@ -39,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{
MERGE, REFMERGE, STREAMMERGE
};

bool IsEnterFeedNode(NodePtr node) {
// For: Enter -> node
// For: Enter -> Cast -> node
// For: Enter -> TransData -> Cast -> node
for (uint8_t i = 0; i < kMaxTransCount; ++i) {
if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str());
return true;
}

const auto all_nodes = node->GetInDataNodes();
if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) {
return false;
}
node = all_nodes.at(0);
}
return false;
}

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
@@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_anchors.emplace(anchor_index);
}
// If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
auto &data_anchors = node_item->enter_data_[this];
data_anchors.emplace(anchor_index);
}
@@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
node_item->root_ctrl_.emplace(this);
}
// If Enter feed control signal, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
node_item->enter_ctrl_.emplace(this);
}
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());


Loading…
Cancel
Save