From 7337e3ceb2299c38af1af2c2caac54f9f04df5ee Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Mon, 17 May 2021 20:25:47 +0800 Subject: [PATCH] Fix DSP for v1 control flow --- ge/graph/common/omg_util.cc | 20 ++++-- ge/graph/common/omg_util.h | 3 +- ge/graph/manager/graph_manager.cc | 3 - ge/graph/partition/dynamic_shape_partition.cc | 60 +++++++++++------- ge/graph/partition/dynamic_shape_partition.h | 9 ++- .../mark_force_unknown_for_cond_pass.cc | 63 ++++++++++++++----- ge/graph/passes/merge_to_stream_merge_pass.cc | 7 ++- ge/graph/passes/next_iteration_pass.cc | 25 ++++---- ge/graph/passes/next_iteration_pass.h | 3 +- .../passes/switch_to_stream_switch_pass.cc | 13 ++-- ge/graph/preprocess/graph_preprocess.cc | 13 ++++ ge/graph/preprocess/graph_preprocess.h | 1 + ge/hybrid/executor/node_state.cc | 41 ++++++++++-- ge/hybrid/executor/node_state.h | 2 + 14 files changed, 190 insertions(+), 73 deletions(-) diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc index 1dba8c51..15fa3c47 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/graph/common/omg_util.cc @@ -272,20 +272,32 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { /// @brief Set Op _force_unknown_shape flag /// @param [in] node /// @param [in] force_unknown, set attribute if true +/// @param [in] group_index, condition group index of node. /// @return /// -void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) { - GE_RT_VOID_CHECK_NOTNULL(node); +void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { if (!force_unknown) { return; } - GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str()); - if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { + GE_RT_VOID_CHECK_NOTNULL(node); + const auto &op_desc = node->GetOpDesc(); + GE_RT_VOID_CHECK_NOTNULL(op_desc); + + // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. + GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index); + if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), node->GetName().c_str(), node->GetType().c_str()); GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), node->GetName().c_str(), node->GetType().c_str()); } + + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + } } } // namespace ge diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h index c84da7f8..fdb0e138 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/graph/common/omg_util.h @@ -129,9 +129,10 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); /// @brief Set Op _force_unknown_shape flag /// @param [in] node /// @param [in] force_unknown, set attribute if true +/// @param [in] group_index, condition group index of node. /// @return /// -void MarkForceUnknownShape(const NodePtr &node, bool force_unknown); +void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); } // namespace ge #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 465ae749..69c84f6f 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -65,7 +65,6 @@ #include "graph/passes/merge_pass.h" #include "graph/passes/merge_input_memcpy_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h" -#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/passes/multi_batch_pass.h" #include "graph/passes/next_iteration_pass.h" #include "graph/passes/permute_pass.h" @@ -2582,8 +2581,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); - auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; - GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) GE_CHK_STATUS_RET( diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc index 516d06d1..8fee1eb5 100755 --- a/ge/graph/partition/dynamic_shape_partition.cc +++ b/ge/graph/partition/dynamic_shape_partition.cc @@ -46,11 +46,6 @@ #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) namespace ge { -namespace { -const std::set kControlFlowOps{ - STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT -}; -} using Cluster = DynamicShapePartitioner::Cluster; using ClusterPtr = std::shared_ptr; @@ -279,9 +274,17 @@ Status DynamicShapePartitioner::InitClusters() { auto cluster = MakeShared(rank++, type, node, this); REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); node_2_cluster_[node] = cluster; - if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) { + if (cluster->IsUnknownShape()) { ordered_cluster_.push_back(cluster); } + + int64_t group_index = -1; + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + GELOGD("[%s] is rts control flow Op, group index: %ld", node->GetName().c_str(), group_index); + auto &control_cluster = control_clusters_[group_index]; + control_cluster.emplace_back(cluster); + } + // Already sorted topologically, so access to the parent cluster is safe for (const auto &parent : node->GetInAllNodes()) { cluster->AddInput(node_2_cluster_[parent]); @@ -350,14 +353,38 @@ static std::string ToString(const std::vector &clusters) { } } +void DynamicShapePartitioner::MergeClustersControlFlow() { + for (const auto &item : control_clusters_) { + const auto &control_cluster = item.second; + auto rit = control_cluster.rbegin(); + if (rit == control_cluster.rend()) { + GELOGW("Invalid empty control flow cluster."); + continue; + } + + const auto &cluster = *rit; + for (++rit; rit != control_cluster.rend(); ++rit) { + const auto &cluster_from = *rit; + auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); + GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), + ToString(merged_clusters).c_str()); + for (const auto &merged_cluster : merged_clusters) { + for (const auto &node : merged_cluster->Nodes()) { + node_2_cluster_[node] = cluster; + } + } + } + } +} + void DynamicShapePartitioner::MergeClustersUnknownShape() { // Merge unknown shape clusters for (const auto &cluster : ordered_cluster_) { - if (cluster->IsIndependent() || cluster->IsControlFlow()) { + if (cluster->IsIndependent()) { continue; } for (const auto &in_cluster : cluster->Inputs()) { - if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) { + if (!in_cluster->IsUnknownShape()) { continue; } auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); @@ -419,6 +446,7 @@ void DynamicShapePartitioner::MergeClustersInputData() { } Status DynamicShapePartitioner::MergeClusters() { + MergeClustersControlFlow(); MergeClustersUnknownShape(); REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); MergeClustersKnownShape(); @@ -608,13 +636,6 @@ bool Cluster::IsRefVariable() const { return false; } -bool Cluster::IsControlFlow() const { - const auto &op_desc = nodes_[0]->GetOpDesc(); - bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); - GELOGD("[%s] %s rts control flow Op ", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not"); - return is_ctrl_flow; -} - void Cluster::AddInput(ClusterPtr in) { if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; in_clusters_.insert(in_clusters_.end(), in); @@ -694,10 +715,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { if (other->IsIndependent()) { return path_clusters; } - if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) == - other->out_clusters_.end()) { - return path_clusters; - } + path_clusters.push_back(other); forward_reached_queue.push(other); backward_reached_queue.push(shared_from_this()); @@ -761,7 +779,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; Status Cluster::BuildFrame() { - if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) { + if (IsUnknownShape() || IsKnownShape() || IsInputNode()) { return BuildPartitionFrame(); } else { auto node = nodes_.front(); @@ -896,7 +914,7 @@ Status Cluster::CombinePartitionFrame() { } Status Cluster::BuildPartitionSubgraph() { - if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) { + if (IsData() || IsNetOutput() || IsIndependent()) { return SUCCESS; } int64_t parent_node_index = 0; diff --git a/ge/graph/partition/dynamic_shape_partition.h b/ge/graph/partition/dynamic_shape_partition.h index 93f86d82..f1d711eb 100644 --- a/ge/graph/partition/dynamic_shape_partition.h +++ b/ge/graph/partition/dynamic_shape_partition.h @@ -47,7 +47,6 @@ class DynamicShapePartitioner { bool IsUnknownShape() const; bool IsIndependent() const; bool IsNetOutput() const; - bool IsControlFlow() const; std::vector> Inputs() const; std::vector> Outputs() const; bool IsInputNode() const; @@ -126,13 +125,15 @@ class DynamicShapePartitioner { // and there's only one path between the two clusters , merge the two clusters // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA Status MergeClusters(); + // Merge clusters step0 + void MergeClustersControlFlow(); // Merge clusters step1 void MergeClustersUnknownShape(); // Merge clusters step2 void MergeClustersKnownShape(); // Merge clusters step3 void MergeClustersInputData(); - // Topological sort clusters after merge unknow shape clusters. + // Topological sort clusters after merge unknown shape clusters. Status TopologicalSortClusters(); // Deduplicate merged clusters void PruneUniqueClusters(); @@ -140,7 +141,7 @@ class DynamicShapePartitioner { Status BuildPartitionFrame(); // Establish connection between corresponding partitioned of clusters Status CombinePartitionFrame(); - // Convert the nodes in cluster into a complete ComputeGraoh + // Convert the nodes in cluster into a complete ComputeGraph Status BuildPartitionSubgraph(); // Clear resource and break circular dependency void ClearResource(); @@ -155,6 +156,8 @@ class DynamicShapePartitioner { Status CtrlEdgeTransfer(); ge::ComputeGraphPtr root_graph_; // The original graph to partition std::unordered_map> node_2_cluster_; // Record nodes and the cluster it belongs to + // V1 control flow cluster, need merge to one Graph. + std::unordered_map>> control_clusters_; // topological sorted clusters, this field will change with the splitting. // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index 6729a647..f6c87d58 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -18,20 +18,25 @@ #include +#include "graph/utils/node_utils.h" #include "graph/common/omg_util.h" namespace ge { namespace { -const std::set kMergeOpTypes{ MERGE, REFMERGE }; +inline bool IsMergeInLoop(const NodePtr &node) { + const static std::set kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; -const std::set kSwitchOpTypes{ SWITCH, REFSWITCH }; + std::string node_type; + (void)GetOriginalType(node, node_type); + return kLoopMergeInputs.count(node_type) > 0; +} -const std::set kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; +inline bool IsSwitchInLoop(const NodePtr &node) { + const static std::set kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; -inline bool IsMergeInLoop(const NodePtr &node) { std::string node_type; (void)GetOriginalType(node, node_type); - return kLoopMergeInputs.count(node_type) > 0; + return kLoopSwitchInputs.count(node_type) > 0; } } @@ -103,7 +108,13 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: if (dst_span > 0) { search_queue.push({in_node, dst_span - 1}); } else { - switch_group.emplace_back(in_node); + const auto &all_in_nodes = in_node->GetInDataNodes(); + if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { + GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), + in_node->GetName().c_str()); + } else { + switch_group.emplace_back(in_node); + } } } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. search_queue.push({in_node, dst_span + 1}); @@ -121,19 +132,37 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: /// void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map> &switch_groups) { std::function callback = [](const NodePtr &n) { - return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); + return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); }; - for (const auto &group : switch_groups) { - const auto &node = group.first; - const auto &switch_group = group.second; - const auto &op_desc = node->GetOpDesc(); - if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || - std::any_of(switch_group.begin(), switch_group.end(), callback)) { - GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str()); - MarkForceUnknownShape(node, true); - for (const auto &n : switch_group) { - MarkForceUnknownShape(n, true); + for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { + const auto &op_node1 = it1->first; + const auto &op_desc1 = op_node1->GetOpDesc(); + if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { + continue; + } + + if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { + int64_t group_index = op_desc1->GetId(); + GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); + MarkForceUnknownShape(op_node1, true, group_index); + for (const auto &n : it1->second) { + MarkForceUnknownShape(n, true, group_index); + } + + for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { + const auto &op_node2 = it2->first; + const auto &op_desc2 = op_node2->GetOpDesc(); + if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { + continue; + } + + if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { + MarkForceUnknownShape(op_node2, true, group_index); + for (const auto &n : it2->second) { + MarkForceUnknownShape(n, true, group_index); + } + } } } } diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index f3a437a6..4c1ad1ae 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -84,8 +84,9 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons GE_CHK_BOOL_EXEC(node != nullptr, REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); return FAILED, "Param of pre node is null."); - bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); - MarkForceUnknownShape(node, force_unknown); + int64_t group_index = -1; + bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); + MarkForceUnknownShape(node, force_unknown, group_index); for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); @@ -102,7 +103,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); return FAILED; } - MarkForceUnknownShape(active_node, force_unknown); + MarkForceUnknownShape(active_node, force_unknown, group_index); } return SUCCESS; diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index 5f4fc4d0..7128b3dc 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -18,6 +18,7 @@ #include "common/ge/ge_util.h" #include "graph/common/omg_util.h" +#include "graph/utils/node_utils.h" using std::string; @@ -203,6 +204,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { for (const auto &loop_cond_iter : loop_group_map_) { const LoopCondGroup &loop_group = *loop_cond_iter.second; const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); + const int64_t group_index = loop_group.loop_cond->GetOpDesc()->GetId(); GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge @@ -223,7 +225,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { enter_active->GetName().c_str()); return INTERNAL_ERROR; } - MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape); + MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); } for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { @@ -253,8 +255,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { return INTERNAL_ERROR; } - MarkForceUnknownShape(next_node, loop_group.is_unknown_shape); - MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape); + MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); + MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); } if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || @@ -263,10 +265,10 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { return INTERNAL_ERROR; } - MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); - MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); - MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); - HandleSwitchExitNodes(loop_group); + MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); + MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); + MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); + HandleSwitchExitNodes(loop_group, group_index); } return SUCCESS; @@ -275,20 +277,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { /// /// @brief Mark force unknown for Exit node /// @param [in] group of LoopCond +/// @param [in] index of LoopCond Node /// @return void /// -void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { +void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { if (!loop_group.is_unknown_shape) { return; } for (const auto &switch_node : loop_group.switch_nodes) { - MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); + MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); for (const auto &node : switch_node->GetOutDataNodes()) { std::string node_type; (void)GetOriginalType(node, node_type); - if (node_type == EXIT || node_type == REFEXIT) { - MarkForceUnknownShape(node, loop_group.is_unknown_shape); + if (kExitOpTypes.count(node_type) > 0) { + MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); } } } diff --git a/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h index e8786516..b6a0846d 100755 --- a/ge/graph/passes/next_iteration_pass.h +++ b/ge/graph/passes/next_iteration_pass.h @@ -96,9 +96,10 @@ class NextIterationPass : public GraphPass { /// /// @brief Mark force unknown for Exit node /// @param [in] group of LoopCond + /// @param [in] index of LoopCond Node /// @return void /// - void HandleSwitchExitNodes(const LoopCondGroup &loop_group); + void HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index); // map std::unordered_map loop_group_map_; diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 949fff41..66a60ab9 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -369,7 +369,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), "StreamSwitch node add cond edge failed."); - MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)); + int64_t group_index = -1; + bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); + MarkForceUnknownShape(stream_switch, force_unknown, group_index); return stream_switch; } @@ -488,11 +490,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) return FAILED; } - std::function callback = [](const NodePtr &n) { - return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); + int64_t group_index = -1; + std::function callback = [&group_index](const NodePtr &n) { + return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); }; bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); - MarkForceUnknownShape(active_node, is_unknown_shape); + MarkForceUnknownShape(active_node, is_unknown_shape, group_index); const std::string &cond_group = cond_node->GetName(); for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { @@ -522,7 +525,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), "Cast add data edge failed."); - MarkForceUnknownShape(stream_switch, is_unknown_shape); + MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); for (const NodePtr &node : switch_list) { GE_IF_BOOL_EXEC(node != stream_switch, { GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 4e9046e4..8597cc61 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -74,6 +74,7 @@ #include "graph/passes/unused_const_pass.h" #include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" +#include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/utils/type_utils.h" #include "inc/pass_manager.h" @@ -1675,6 +1676,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); + PP_RUN_AND_DUMP("CtrlFlowPreProcess", CtrlFlowPreProcess); PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); @@ -1683,6 +1685,17 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: return SUCCESS; } +Status GraphPrepare::CtrlFlowPreProcess() { + PassManager graph_pass; + + // After InferShape Mark v1 control flow for unknown shape. + auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; + GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass", mark_force_unknown_pass)); + + GE_CHK_STATUS_RET(graph_pass.Run(compute_graph_)); + return SUCCESS; +} + Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); return SUCCESS; diff --git a/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h index 9dc3e679..3eb5e03a 100755 --- a/ge/graph/preprocess/graph_preprocess.h +++ b/ge/graph/preprocess/graph_preprocess.h @@ -79,6 +79,7 @@ class GraphPrepare { Status ProcessNetOutput(); Status ProcessBeforeInfershape(); Status UpdateInputOutputByOptions(); + Status CtrlFlowPreProcess(); bool IsTansDataOpData(const ge::NodePtr &var_node); diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index aaa7801f..9ec5431a 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -104,11 +104,47 @@ void ShapeInferenceState::UpdateInputShapeFuture(int idx, ShapeFuture &&future) } } +Status ShapeInferenceState::UpdateInputForMerge(const GraphExecutionContext &context) { + int merge_index = -1; + const auto &guard = node_item.MutexGuard("UpdateInputForMerge"); + if (!AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_MERGE_INPUT_INDEX, merge_index)) { + GELOGE(FAILED, "[%s] Get attr %s failed", node_item.NodeName().c_str(), ATTR_NAME_MERGE_INPUT_INDEX.c_str()); + return FAILED; + } + + if (merge_index < 0 || static_cast(merge_index) >= input_tensor_desc.size()) { + GELOGE(FAILED, "[%s] merge index: %d invalid, should in range[0, %zu)", + node_item.NodeName().c_str(), merge_index, input_tensor_desc.size()); + return FAILED; + } + + auto dst_tensor_desc = node_item.MutableInputDesc(merge_index); + GE_CHECK_NOTNULL(dst_tensor_desc); + + int64_t tensor_size = -1; + auto &tensor_desc = input_tensor_desc[merge_index]; + (void)TensorUtils::GetSize(tensor_desc, tensor_size); + + dst_tensor_desc->SetShape(tensor_desc.MutableShape()); + dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); + (void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size); + (void)guard; + GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", + node_item.NodeName().c_str(), merge_index, dst_tensor_desc->GetShape().ToString().c_str(), + dst_tensor_desc->GetOriginShape().ToString().c_str(), tensor_size); + + return SUCCESS; +} + Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { if (!node_item.is_dynamic) { return SUCCESS; } std::unique_lock lk(mu_); + if (node_item.IsMergeOp()) { + return UpdateInputForMerge(context); + } + if (num_pending_shapes_ > 0) { GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); int try_count = 0; @@ -169,7 +205,7 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex int64_t tensor_size = -1; (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); - GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", + GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", node_item.NodeName().c_str(), idx, src_tensor_desc->GetShape().ToString().c_str(), @@ -283,11 +319,8 @@ void NodeState::ResetContext(int group) { } switch_index_ = -1; - const auto &guard = node_item_->MutexGuard("ResetContext"); - shape_inference_state_.InitShapeState(); subgraph_context_->ResetContext(node_item_->node); GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); - (void)guard; } void NodeState::ResetSchedule() { diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 49861611..d3f176ce 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -67,6 +67,8 @@ struct ShapeInferenceState { const NodeItem &node_item; private: + Status UpdateInputForMerge(const GraphExecutionContext &context); + friend struct NodeState; std::vector> shape_futures; // do not directly update op_desc, in case race condition across pipelines