| @@ -274,21 +274,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||
| return false; | |||
| } | |||
| /// | |||
| /// @brief Set Op _force_unknown_shape flag | |||
| /// @param [in] node | |||
| /// @param [in] force_unknown, set attribute if true | |||
| /// @param [in] group_index, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||
| if (!force_unknown) { | |||
| return; | |||
| } | |||
| SetControlFlowGroup(node, group_index); | |||
| } | |||
| /// | |||
| /// @brief Set Op _control_flow_group flag | |||
| /// @param [in] node | |||
| @@ -125,15 +125,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||
| /// | |||
| bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
| /// | |||
| /// @brief Set Op _force_unknown_shape flag | |||
| /// @param [in] node | |||
| /// @param [in] force_unknown, set attribute if true | |||
| /// @param [in] group_index, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
| /// | |||
| /// @brief Set Op _control_flow_group flag | |||
| /// @param [in] node | |||
| @@ -132,38 +132,18 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||
| }; | |||
| for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||
| const auto &op_node1 = it1->first; | |||
| const auto &op_desc1 = op_node1->GetOpDesc(); | |||
| if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||
| const auto &op_node = it->first; | |||
| const auto &op_desc = op_node->GetOpDesc(); | |||
| if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| } | |||
| int64_t group_index = op_desc1->GetId(); | |||
| GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||
| SetControlFlowGroup(op_node1, group_index); | |||
| for (const auto &n : it1->second) { | |||
| int64_t group_index = op_desc->GetId(); | |||
| SetControlFlowGroup(op_node, group_index); | |||
| for (const auto &n : it->second) { | |||
| SetControlFlowGroup(n, group_index); | |||
| } | |||
| for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||
| const auto &op_node2 = it2->first; | |||
| const auto &op_desc2 = op_node2->GetOpDesc(); | |||
| if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| } | |||
| if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||
| SetControlFlowGroup(op_node2, group_index); | |||
| for (const auto &n : it2->second) { | |||
| SetControlFlowGroup(n, group_index); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace ge | |||
| @@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||
| return FAILED, "[Check][Param] Param of pre node is nullptr."); | |||
| int64_t group_index = -1; | |||
| bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| MarkForceUnknownShape(node, force_unknown, group_index); | |||
| (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | |||
| @@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
| GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| MarkForceUnknownShape(active_node, force_unknown, group_index); | |||
| SetControlFlowGroup(active_node, group_index); | |||
| } | |||
| return SUCCESS; | |||
| @@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
| peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | |||
| int64_t group_index = -1; | |||
| bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||
| (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| SetControlFlowGroup(stream_switch, group_index); | |||
| return stream_switch; | |||
| } | |||
| @@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
| Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | |||
| for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | |||
| for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | |||
| std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
| std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
| const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
| const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
| std::set<NodePtr> same_cond_switch; | |||
| same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | |||
| same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | |||
| @@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||
| return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| }; | |||
| bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
| MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||
| (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
| SetControlFlowGroup(active_node, group_index); | |||
| const std::string &cond_group = cond_node->GetName(); | |||
| for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||
| bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | |||
| std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
| const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
| GE_IF_BOOL_EXEC(switch_list.empty(), continue); | |||
| // select first stream_switch | |||
| @@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| "[Add][Edge] between %s and %s failed.", | |||
| cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | |||
| MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||
| SetControlFlowGroup(stream_switch, group_index); | |||
| for (const NodePtr &node : switch_list) { | |||
| GE_IF_BOOL_EXEC(node != stream_switch, { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | |||
| @@ -317,9 +317,9 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
| return task_context_; | |||
| } | |||
| void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||
| void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||
| if (node_item_->root_data_.count(input_idx) > 0) { | |||
| GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); | |||
| GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
| root_tensor_values_[input_idx] = tensor; | |||
| } | |||
| @@ -329,7 +329,7 @@ void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||
| } | |||
| } | |||
| void NodeState::UpdateRootTensor(int input_idx) { | |||
| void NodeState::UpdatePersistTensor(int input_idx) { | |||
| const auto it = root_tensor_values_.find(input_idx); | |||
| if (it == root_tensor_values_.end()) { | |||
| GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | |||
| @@ -355,14 +355,14 @@ void NodeState::ResetContext(uint64_t iteration) { | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
| for (auto item : node_item_->root_data_) { | |||
| UpdateRootTensor(item.first); | |||
| UpdatePersistTensor(item.first); | |||
| } | |||
| if (iteration > 0) { | |||
| data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||
| ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||
| for (auto item : node_item_->enter_data_) { | |||
| UpdateRootTensor(item.first); | |||
| UpdatePersistTensor(item.first); | |||
| } | |||
| } | |||
| @@ -129,7 +129,7 @@ struct NodeState { | |||
| void RunStreamActive(); | |||
| void RunNextIteration(); | |||
| void SaveRootTensor(int input_idx, const TensorValue &tensor); | |||
| void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||
| Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
| @@ -189,7 +189,7 @@ struct NodeState { | |||
| void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
| void ResetContext(uint64_t iteration); | |||
| void ScheduleContext(const NodeState &node_state); | |||
| void UpdateRootTensor(int input_idx); | |||
| void UpdatePersistTensor(int input_idx); | |||
| const NodeItem *node_item_ = nullptr; | |||
| std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
| @@ -461,7 +461,7 @@ Status TaskContext::PropagateOutputs() { | |||
| auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||
| GE_CHECK_NOTNULL(dst_node_state); | |||
| dst_node_state->SaveRootTensor(dst_input_idx, *tensor); | |||
| dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||
| } | |||
| } | |||
| (void)guard; | |||