Browse Source

Fix root feed inner Enter

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
8ba51653b3
2 changed files with 7 additions and 3 deletions
  1. +2
    -0
      ge/hybrid/executor/subgraph_context.cc
  2. +5
    -3
      ge/hybrid/model/hybrid_model_builder.cc

+ 2
- 0
ge/hybrid/executor/subgraph_context.cc View File

@@ -106,6 +106,8 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
FrameStatePtr SubgraphContext::GetOrCreateFrameState(const NodeItem &node_item) {
auto &frame_state = frame_states_[node_item.frame_index_];
if (frame_state == nullptr) {
GELOGD("[%s] Create FrameState, frame index: %ld, parent frame index: %ld",
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);
frame_state.reset(new(std::nothrow)FrameState(node_item.frame_index_));
if (node_item.frame_index_ != -1) { // -1 is root frame.
frame_state->parent_frame_ = frame_states_[node_item.parent_frame_];


+ 5
- 3
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -2358,13 +2358,15 @@ Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) {
int64_t ctrl_flow_group = -1;
if (node_item.IsEnterOp() && AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, ctrl_flow_group)) {
node_item.frame_index_ = ctrl_flow_group;
if (node_item.IsEnterOp()) {
const auto src_node = node_item.node->GetInDataNodes().at(0);
for (const auto src_node : node_item.node->GetInDataNodes()) {
NodeItem *src_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item),
"[%s] failed to get or create node item", src_node->GetName().c_str());
if (!src_node_item->is_root_node_) {
GELOGD("[%s] frame index: %ld, from [%s] get parent frame index: %ld", node_item.node_name.c_str(),
node_item.frame_index_, src_node_item->node_name.c_str(), src_node_item->frame_index_);
parent_frame_group_[node_item.frame_index_] = src_node_item->frame_index_;
break;
}
}

@@ -2390,7 +2392,7 @@ Status HybridModelBuilder::BuildFrameGroupIndex(NodeItem &node_item) {
node_item.frame_index_ = src_node_item->frame_index_;
}

const auto it = parent_frame_group_.find(src_node_item->frame_index_);
const auto it = parent_frame_group_.find(node_item.frame_index_);
node_item.parent_frame_ = (it != parent_frame_group_.end()) ? it->second : -1;
GELOGD("[%s] control flow frame group: %ld, parent frame: %ld",
node_item.node_name.c_str(), node_item.frame_index_, node_item.parent_frame_);


Loading…
Cancel
Save