| @@ -272,20 +272,32 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||
| /// @brief Set Op _force_unknown_shape flag | |||
| /// @param [in] node | |||
| /// @param [in] force_unknown, set attribute if true | |||
| /// @param [in] group_index, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) { | |||
| GE_RT_VOID_CHECK_NOTNULL(node); | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||
| if (!force_unknown) { | |||
| return; | |||
| } | |||
| GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str()); | |||
| if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { | |||
| GE_RT_VOID_CHECK_NOTNULL(node); | |||
| const auto &op_desc = node->GetOpDesc(); | |||
| GE_RT_VOID_CHECK_NOTNULL(op_desc); | |||
| // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | |||
| GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index); | |||
| if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { | |||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | |||
| node->GetName().c_str(), node->GetType().c_str()); | |||
| GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | |||
| node->GetName().c_str(), node->GetType().c_str()); | |||
| } | |||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||
| node->GetName().c_str(), node->GetType().c_str()); | |||
| GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||
| node->GetName().c_str(), node->GetType().c_str()); | |||
| } | |||
| } | |||
| } // namespace ge | |||
| @@ -129,9 +129,10 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
| /// @brief Set Op _force_unknown_shape flag | |||
| /// @param [in] node | |||
| /// @param [in] force_unknown, set attribute if true | |||
| /// @param [in] group_index, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown); | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ | |||
| @@ -65,7 +65,6 @@ | |||
| #include "graph/passes/merge_pass.h" | |||
| #include "graph/passes/merge_input_memcpy_pass.h" | |||
| #include "graph/passes/merge_to_stream_merge_pass.h" | |||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||
| #include "graph/passes/multi_batch_pass.h" | |||
| #include "graph/passes/next_iteration_pass.h" | |||
| #include "graph/passes/permute_pass.h" | |||
| @@ -2582,8 +2581,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); | |||
| auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | |||
| GE_CHK_STATUS_RET( | |||
| @@ -46,11 +46,6 @@ | |||
| #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | |||
| namespace ge { | |||
| namespace { | |||
| const std::set<std::string> kControlFlowOps{ | |||
| STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT | |||
| }; | |||
| } | |||
| using Cluster = DynamicShapePartitioner::Cluster; | |||
| using ClusterPtr = std::shared_ptr<Cluster>; | |||
| @@ -279,9 +274,17 @@ Status DynamicShapePartitioner::InitClusters() { | |||
| auto cluster = MakeShared<Cluster>(rank++, type, node, this); | |||
| REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | |||
| node_2_cluster_[node] = cluster; | |||
| if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) { | |||
| if (cluster->IsUnknownShape()) { | |||
| ordered_cluster_.push_back(cluster); | |||
| } | |||
| int64_t group_index = -1; | |||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
| GELOGD("[%s] is rts control flow Op, group index: %ld", node->GetName().c_str(), group_index); | |||
| auto &control_cluster = control_clusters_[group_index]; | |||
| control_cluster.emplace_back(cluster); | |||
| } | |||
| // Already sorted topologically, so access to the parent cluster is safe | |||
| for (const auto &parent : node->GetInAllNodes()) { | |||
| cluster->AddInput(node_2_cluster_[parent]); | |||
| @@ -350,14 +353,38 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||
| } | |||
| } | |||
| void DynamicShapePartitioner::MergeClustersControlFlow() { | |||
| for (const auto &item : control_clusters_) { | |||
| const auto &control_cluster = item.second; | |||
| auto rit = control_cluster.rbegin(); | |||
| if (rit == control_cluster.rend()) { | |||
| GELOGW("Invalid empty control flow cluster."); | |||
| continue; | |||
| } | |||
| const auto &cluster = *rit; | |||
| for (++rit; rit != control_cluster.rend(); ++rit) { | |||
| const auto &cluster_from = *rit; | |||
| auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | |||
| GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | |||
| ToString(merged_clusters).c_str()); | |||
| for (const auto &merged_cluster : merged_clusters) { | |||
| for (const auto &node : merged_cluster->Nodes()) { | |||
| node_2_cluster_[node] = cluster; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void DynamicShapePartitioner::MergeClustersUnknownShape() { | |||
| // Merge unknown shape clusters | |||
| for (const auto &cluster : ordered_cluster_) { | |||
| if (cluster->IsIndependent() || cluster->IsControlFlow()) { | |||
| if (cluster->IsIndependent()) { | |||
| continue; | |||
| } | |||
| for (const auto &in_cluster : cluster->Inputs()) { | |||
| if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) { | |||
| if (!in_cluster->IsUnknownShape()) { | |||
| continue; | |||
| } | |||
| auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | |||
| @@ -419,6 +446,7 @@ void DynamicShapePartitioner::MergeClustersInputData() { | |||
| } | |||
| Status DynamicShapePartitioner::MergeClusters() { | |||
| MergeClustersControlFlow(); | |||
| MergeClustersUnknownShape(); | |||
| REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); | |||
| MergeClustersKnownShape(); | |||
| @@ -608,13 +636,6 @@ bool Cluster::IsRefVariable() const { | |||
| return false; | |||
| } | |||
| bool Cluster::IsControlFlow() const { | |||
| const auto &op_desc = nodes_[0]->GetOpDesc(); | |||
| bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||
| GELOGD("[%s] %s rts control flow Op ", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not"); | |||
| return is_ctrl_flow; | |||
| } | |||
| void Cluster::AddInput(ClusterPtr in) { | |||
| if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; | |||
| in_clusters_.insert(in_clusters_.end(), in); | |||
| @@ -694,10 +715,7 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) { | |||
| if (other->IsIndependent()) { | |||
| return path_clusters; | |||
| } | |||
| if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) == | |||
| other->out_clusters_.end()) { | |||
| return path_clusters; | |||
| } | |||
| path_clusters.push_back(other); | |||
| forward_reached_queue.push(other); | |||
| backward_reached_queue.push(shared_from_this()); | |||
| @@ -761,7 +779,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> | |||
| OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | |||
| Status Cluster::BuildFrame() { | |||
| if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) { | |||
| if (IsUnknownShape() || IsKnownShape() || IsInputNode()) { | |||
| return BuildPartitionFrame(); | |||
| } else { | |||
| auto node = nodes_.front(); | |||
| @@ -896,7 +914,7 @@ Status Cluster::CombinePartitionFrame() { | |||
| } | |||
| Status Cluster::BuildPartitionSubgraph() { | |||
| if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) { | |||
| if (IsData() || IsNetOutput() || IsIndependent()) { | |||
| return SUCCESS; | |||
| } | |||
| int64_t parent_node_index = 0; | |||
| @@ -47,7 +47,6 @@ class DynamicShapePartitioner { | |||
| bool IsUnknownShape() const; | |||
| bool IsIndependent() const; | |||
| bool IsNetOutput() const; | |||
| bool IsControlFlow() const; | |||
| std::vector<std::shared_ptr<Cluster>> Inputs() const; | |||
| std::vector<std::shared_ptr<Cluster>> Outputs() const; | |||
| bool IsInputNode() const; | |||
| @@ -126,13 +125,15 @@ class DynamicShapePartitioner { | |||
| // and there's only one path between the two clusters , merge the two clusters | |||
| // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA | |||
| Status MergeClusters(); | |||
| // Merge clusters step0 | |||
| void MergeClustersControlFlow(); | |||
| // Merge clusters step1 | |||
| void MergeClustersUnknownShape(); | |||
| // Merge clusters step2 | |||
| void MergeClustersKnownShape(); | |||
| // Merge clusters step3 | |||
| void MergeClustersInputData(); | |||
| // Topological sort clusters after merge unknow shape clusters. | |||
| // Topological sort clusters after merge unknown shape clusters. | |||
| Status TopologicalSortClusters(); | |||
| // Deduplicate merged clusters | |||
| void PruneUniqueClusters(); | |||
| @@ -140,7 +141,7 @@ class DynamicShapePartitioner { | |||
| Status BuildPartitionFrame(); | |||
| // Establish connection between corresponding partitioned of clusters | |||
| Status CombinePartitionFrame(); | |||
| // Convert the nodes in cluster into a complete ComputeGraoh | |||
| // Convert the nodes in cluster into a complete ComputeGraph | |||
| Status BuildPartitionSubgraph(); | |||
| // Clear resource and break circular dependency | |||
| void ClearResource(); | |||
| @@ -155,6 +156,8 @@ class DynamicShapePartitioner { | |||
| Status CtrlEdgeTransfer(); | |||
| ge::ComputeGraphPtr root_graph_; // The original graph to partition | |||
| std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | |||
| // V1 control flow clusters, keyed by control flow group index; the clusters of one group need to be merged into one graph. | |||
| std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||
| // topological sorted clusters, this field will change with the splitting. | |||
| // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | |||
| // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | |||
| @@ -18,20 +18,25 @@ | |||
| #include <queue> | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/common/omg_util.h" | |||
| namespace ge { | |||
| namespace { | |||
| const std::set<std::string> kMergeOpTypes{ MERGE, REFMERGE }; | |||
| inline bool IsMergeInLoop(const NodePtr &node) { | |||
| const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
| const std::set<std::string> kSwitchOpTypes{ SWITCH, REFSWITCH }; | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| return kLoopMergeInputs.count(node_type) > 0; | |||
| } | |||
| const std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
| inline bool IsSwitchInLoop(const NodePtr &node) { | |||
| const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||
| inline bool IsMergeInLoop(const NodePtr &node) { | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| return kLoopMergeInputs.count(node_type) > 0; | |||
| return kLoopSwitchInputs.count(node_type) > 0; | |||
| } | |||
| } | |||
| @@ -103,7 +108,13 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
| if (dst_span > 0) { | |||
| search_queue.push({in_node, dst_span - 1}); | |||
| } else { | |||
| switch_group.emplace_back(in_node); | |||
| const auto &all_in_nodes = in_node->GetInDataNodes(); | |||
| if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||
| GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||
| in_node->GetName().c_str()); | |||
| } else { | |||
| switch_group.emplace_back(in_node); | |||
| } | |||
| } | |||
| } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | |||
| search_queue.push({in_node, dst_span + 1}); | |||
| @@ -121,19 +132,37 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
| /// | |||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||
| }; | |||
| for (const auto &group : switch_groups) { | |||
| const auto &node = group.first; | |||
| const auto &switch_group = group.second; | |||
| const auto &op_desc = node->GetOpDesc(); | |||
| if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || | |||
| std::any_of(switch_group.begin(), switch_group.end(), callback)) { | |||
| GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str()); | |||
| MarkForceUnknownShape(node, true); | |||
| for (const auto &n : switch_group) { | |||
| MarkForceUnknownShape(n, true); | |||
| for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||
| const auto &op_node1 = it1->first; | |||
| const auto &op_desc1 = op_node1->GetOpDesc(); | |||
| if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| } | |||
| if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||
| int64_t group_index = op_desc1->GetId(); | |||
| GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||
| MarkForceUnknownShape(op_node1, true, group_index); | |||
| for (const auto &n : it1->second) { | |||
| MarkForceUnknownShape(n, true, group_index); | |||
| } | |||
| for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||
| const auto &op_node2 = it2->first; | |||
| const auto &op_desc2 = op_node2->GetOpDesc(); | |||
| if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| } | |||
| if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||
| MarkForceUnknownShape(op_node2, true, group_index); | |||
| for (const auto &n : it2->second) { | |||
| MarkForceUnknownShape(n, true, group_index); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -84,8 +84,9 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
| GE_CHK_BOOL_EXEC(node != nullptr, | |||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||
| return FAILED, "Param of pre node is null."); | |||
| bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||
| MarkForceUnknownShape(node, force_unknown); | |||
| int64_t group_index = -1; | |||
| bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| MarkForceUnknownShape(node, force_unknown, group_index); | |||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | |||
| @@ -102,7 +103,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
| GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| MarkForceUnknownShape(active_node, force_unknown); | |||
| MarkForceUnknownShape(active_node, force_unknown, group_index); | |||
| } | |||
| return SUCCESS; | |||
| @@ -18,6 +18,7 @@ | |||
| #include "common/ge/ge_util.h" | |||
| #include "graph/common/omg_util.h" | |||
| #include "graph/utils/node_utils.h" | |||
| using std::string; | |||
| @@ -203,6 +204,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| for (const auto &loop_cond_iter : loop_group_map_) { | |||
| const LoopCondGroup &loop_group = *loop_cond_iter.second; | |||
| const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | |||
| const int64_t group_index = loop_group.loop_cond->GetOpDesc()->GetId(); | |||
| GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||
| // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||
| @@ -223,7 +225,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| enter_active->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape); | |||
| MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); | |||
| } | |||
| for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | |||
| @@ -253,8 +255,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| return INTERNAL_ERROR; | |||
| } | |||
| MarkForceUnknownShape(next_node, loop_group.is_unknown_shape); | |||
| MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape); | |||
| MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); | |||
| MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); | |||
| } | |||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||
| @@ -263,10 +265,10 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| return INTERNAL_ERROR; | |||
| } | |||
| MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | |||
| MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | |||
| MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | |||
| HandleSwitchExitNodes(loop_group); | |||
| MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); | |||
| MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); | |||
| MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); | |||
| HandleSwitchExitNodes(loop_group, group_index); | |||
| } | |||
| return SUCCESS; | |||
| @@ -275,20 +277,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| /// | |||
| /// @brief Mark force unknown for Exit node | |||
| /// @param [in] group of LoopCond | |||
| /// @param [in] index of LoopCond Node | |||
| /// @return void | |||
| /// | |||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { | |||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||
| if (!loop_group.is_unknown_shape) { | |||
| return; | |||
| } | |||
| for (const auto &switch_node : loop_group.switch_nodes) { | |||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); | |||
| for (const auto &node : switch_node->GetOutDataNodes()) { | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| if (node_type == EXIT || node_type == REFEXIT) { | |||
| MarkForceUnknownShape(node, loop_group.is_unknown_shape); | |||
| if (kExitOpTypes.count(node_type) > 0) { | |||
| MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); | |||
| } | |||
| } | |||
| } | |||
| @@ -96,9 +96,10 @@ class NextIterationPass : public GraphPass { | |||
| /// | |||
| /// @brief Mark force unknown for Exit node | |||
| /// @param [in] group of LoopCond | |||
| /// @param [in] index of LoopCond Node | |||
| /// @return void | |||
| /// | |||
| void HandleSwitchExitNodes(const LoopCondGroup &loop_group); | |||
| void HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index); | |||
| // map<frame_name, LoopCondGroup> | |||
| std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | |||
| @@ -369,7 +369,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
| GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), | |||
| "StreamSwitch node add cond edge failed."); | |||
| MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)); | |||
| int64_t group_index = -1; | |||
| bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||
| return stream_switch; | |||
| } | |||
| @@ -488,11 +490,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| return FAILED; | |||
| } | |||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||
| int64_t group_index = -1; | |||
| std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||
| return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| }; | |||
| bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
| MarkForceUnknownShape(active_node, is_unknown_shape); | |||
| MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||
| const std::string &cond_group = cond_node->GetName(); | |||
| for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||
| @@ -522,7 +525,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), | |||
| "Cast add data edge failed."); | |||
| MarkForceUnknownShape(stream_switch, is_unknown_shape); | |||
| MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||
| for (const NodePtr &node : switch_list) { | |||
| GE_IF_BOOL_EXEC(node != stream_switch, { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | |||
| @@ -74,6 +74,7 @@ | |||
| #include "graph/passes/unused_const_pass.h" | |||
| #include "graph/passes/var_is_initialized_op_pass.h" | |||
| #include "graph/passes/variable_prepare_op_pass.h" | |||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||
| #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "inc/pass_manager.h" | |||
| @@ -1675,6 +1676,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||
| PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | |||
| PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | |||
| PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); | |||
| PP_RUN_AND_DUMP("CtrlFlowPreProcess", CtrlFlowPreProcess); | |||
| PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); | |||
| PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); | |||
| PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); | |||
| @@ -1683,6 +1685,17 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||
| return SUCCESS; | |||
| } | |||
| Status GraphPrepare::CtrlFlowPreProcess() { | |||
| PassManager graph_pass; | |||
| // After InferShape Mark v1 control flow for unknown shape. | |||
| auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; | |||
| GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass", mark_force_unknown_pass)); | |||
| GE_CHK_STATUS_RET(graph_pass.Run(compute_graph_)); | |||
| return SUCCESS; | |||
| } | |||
| Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { | |||
| PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); | |||
| return SUCCESS; | |||
| @@ -79,6 +79,7 @@ class GraphPrepare { | |||
| Status ProcessNetOutput(); | |||
| Status ProcessBeforeInfershape(); | |||
| Status UpdateInputOutputByOptions(); | |||
| Status CtrlFlowPreProcess(); | |||
| bool IsTansDataOpData(const ge::NodePtr &var_node); | |||
| @@ -104,11 +104,47 @@ void ShapeInferenceState::UpdateInputShapeFuture(int idx, ShapeFuture &&future) | |||
| } | |||
| } | |||
| Status ShapeInferenceState::UpdateInputForMerge(const GraphExecutionContext &context) { | |||
| int merge_index = -1; | |||
| const auto &guard = node_item.MutexGuard("UpdateInputForMerge"); | |||
| if (!AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_MERGE_INPUT_INDEX, merge_index)) { | |||
| GELOGE(FAILED, "[%s] Get attr %s failed", node_item.NodeName().c_str(), ATTR_NAME_MERGE_INPUT_INDEX.c_str()); | |||
| return FAILED; | |||
| } | |||
| if (merge_index < 0 || static_cast<size_t>(merge_index) >= input_tensor_desc.size()) { | |||
| GELOGE(FAILED, "[%s] merge index: %d invalid, should in range[0, %zu)", | |||
| node_item.NodeName().c_str(), merge_index, input_tensor_desc.size()); | |||
| return FAILED; | |||
| } | |||
| auto dst_tensor_desc = node_item.MutableInputDesc(merge_index); | |||
| GE_CHECK_NOTNULL(dst_tensor_desc); | |||
| int64_t tensor_size = -1; | |||
| auto &tensor_desc = input_tensor_desc[merge_index]; | |||
| (void)TensorUtils::GetSize(tensor_desc, tensor_size); | |||
| dst_tensor_desc->SetShape(tensor_desc.MutableShape()); | |||
| dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); | |||
| (void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size); | |||
| (void)guard; | |||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", | |||
| node_item.NodeName().c_str(), merge_index, dst_tensor_desc->GetShape().ToString().c_str(), | |||
| dst_tensor_desc->GetOriginShape().ToString().c_str(), tensor_size); | |||
| return SUCCESS; | |||
| } | |||
| Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { | |||
| if (!node_item.is_dynamic) { | |||
| return SUCCESS; | |||
| } | |||
| std::unique_lock<std::mutex> lk(mu_); | |||
| if (node_item.IsMergeOp()) { | |||
| return UpdateInputForMerge(context); | |||
| } | |||
| if (num_pending_shapes_ > 0) { | |||
| GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); | |||
| int try_count = 0; | |||
| @@ -169,7 +205,7 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||
| int64_t tensor_size = -1; | |||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | |||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | |||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", | |||
| node_item.NodeName().c_str(), | |||
| idx, | |||
| src_tensor_desc->GetShape().ToString().c_str(), | |||
| @@ -283,11 +319,8 @@ void NodeState::ResetContext(int group) { | |||
| } | |||
| switch_index_ = -1; | |||
| const auto &guard = node_item_->MutexGuard("ResetContext"); | |||
| shape_inference_state_.InitShapeState(); | |||
| subgraph_context_->ResetContext(node_item_->node); | |||
| GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | |||
| (void)guard; | |||
| } | |||
| void NodeState::ResetSchedule() { | |||
| @@ -67,6 +67,8 @@ struct ShapeInferenceState { | |||
| const NodeItem &node_item; | |||
| private: | |||
| Status UpdateInputForMerge(const GraphExecutionContext &context); | |||
| friend struct NodeState; | |||
| std::vector<std::pair<int, ShapeFuture>> shape_futures; | |||
| // do not directly update op_desc, in case race condition across pipelines | |||