Browse Source

IsEnterFeedNode

tags/v1.5.1
lianghao 3 years ago
parent
commit
70a9868d3b
1 changed files with 23 additions and 2 deletions
  1. +23
    -2
      ge/hybrid/model/node_item.cc

+ 23
- 2
ge/hybrid/model/node_item.cc View File

@@ -24,6 +24,8 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
namespace { namespace {
const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal"; const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{ const std::set<std::string> kControlOpTypes{
@@ -39,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{
MERGE, REFMERGE, STREAMMERGE MERGE, REFMERGE, STREAMMERGE
}; };


bool IsEnterFeedNode(NodePtr node) {
// For: Enter -> node
// For: Enter -> Cast -> node
// For: Enter -> TransData -> Cast -> node
for (uint8_t i = 0; i < kMaxTransCount; ++i) {
if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
GELOGD("Node[%u] is Enter feed node.", node->GetName().c_str());
return true;
}

const auto all_nodes = node->GetInDataNodes();
if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) {
return false;
}
node = all_nodes.at(0);
}
return false;
}

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0; uint32_t parent_index = 0;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
@@ -399,7 +420,7 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_anchors.emplace(anchor_index); data_anchors.emplace(anchor_index);
} }
// If Enter feed Not Merge, take as root Node. // If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
auto &data_anchors = node_item->enter_data_[this]; auto &data_anchors = node_item->enter_data_[this];
data_anchors.emplace(anchor_index); data_anchors.emplace(anchor_index);
} }
@@ -419,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
node_item->root_ctrl_.emplace(this); node_item->root_ctrl_.emplace(this);
} }
// If Enter feed control signal, take as root Node. // If Enter feed control signal, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
node_item->enter_ctrl_.emplace(this); node_item->enter_ctrl_.emplace(this);
} }
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());


Loading…
Cancel
Save