diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 3a0f7638..3da80492 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -157,6 +157,8 @@ set(TRAIN_SRC_LIST "graph/passes/compile_nodes_pass.cc" "graph/passes/constant_folding_pass.cc" "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/remove_same_const_pass.cc" + "graph/passes/useless_control_out_remove_pass.cc" "graph/passes/control_trigger_pass.cc" "graph/passes/dimension_adjust_pass.cc" "graph/passes/dimension_compute_pass.cc" @@ -522,6 +524,8 @@ set(INFER_SRC_LIST "graph/passes/assign_pass.cc" "graph/passes/addn_pass.cc" "graph/passes/common_subexpression_elimination_pass.cc" + "graph/passes/remove_same_const_pass.cc" + "graph/passes/useless_control_out_remove_pass.cc" "graph/passes/transop_symmetry_elimination_pass.cc" "graph/passes/save_pass.cc" "graph/passes/switch_dead_branch_elimination.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index bfb612ea..e20456d5 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -191,6 +191,8 @@ OMG_HOST_SRC_FILES := \ graph/passes/control_trigger_pass.cc \ graph/passes/cond_pass.cc \ graph/passes/cond_remove_pass.cc \ + graph/passes/remove_same_const_pass.cc \ + graph/passes/useless_control_out_remove_pass.cc \ graph/passes/for_pass.cc \ graph/passes/enter_pass.cc \ graph/passes/assign_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 25718e9b..9706dadb 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -126,6 +126,8 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/compile_nodes_pass.cc \ graph/passes/constant_folding_pass.cc \ graph/passes/constant_fuse_same_pass.cc \ + graph/passes/remove_same_const_pass.cc \ + graph/passes/useless_control_out_remove_pass.cc \ graph/passes/control_trigger_pass.cc \ graph/passes/dimension_adjust_pass.cc \ graph/passes/dimension_compute_pass.cc \ diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 37eb499a..3be45895 100755 --- a/ge/graph/build/model_builder.cc +++ 
b/ge/graph/build/model_builder.cc @@ -224,6 +224,7 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_ GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); size_t output_size = weight->GetData().size(); TensorUtils::SetDataOffset(tensor_desc, mem_offset); + GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size); mem_offset += output_size; } return SUCCESS; diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 46a0ec2e..46799ba3 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -65,6 +65,7 @@ #include "graph/passes/permute_pass.h" #include "graph/passes/prune_pass.h" #include "graph/passes/ref_identity_delete_op_pass.h" +#include "graph/passes/remove_same_const_pass.h" #include "graph/passes/reshape_recovery_pass.h" #include "graph/passes/reshape_remove_pass.h" #include "graph/passes/same_transdata_breadth_fusion_pass.h" @@ -78,6 +79,7 @@ #include "graph/passes/transop_symmetry_elimination_pass.h" #include "graph/passes/transop_without_reshape_fusion_pass.h" #include "graph/passes/transpose_transdata_pass.h" +#include "graph/passes/useless_control_out_remove_pass.h" #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" @@ -2130,6 +2132,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { TransposeTransDataPass transpose_transdata_pass; TransOpSymmetryEliminationPass symmetry_elimination_pass; DimensionComputePass dimension_compute_pass; + UselessControlOutRemovePass useless_control_out_remove_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); names_to_passes.emplace_back("AddNPass", &addn_pass); names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); @@ -2143,6 +2146,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { 
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); + names_to_passes.emplace_back("UselessControlOutRemovePass", &useless_control_out_remove_pass); GE_TIMESTAMP_START(names_to_passes); ret = GEPass(compute_graph).Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); @@ -2183,6 +2187,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", new (std::nothrow) VariableRefUselessControlOutDeletePass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) + GE_CHK_STATUS_RET( + graph_pass.AddPass("OptimizeStage1_3::RemoveSameConstPass", new (std::nothrow) RemoveSameConstPass)) if (options_.train_graph_flag) { // Priority: The GlobalStepInsertPass should work before graph partitioner. 
// Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index c0e0f669..cd3509c7 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -18,6 +18,8 @@ #include "ge/ge_api_types.h" #include "graph/common/omg_util.h" +using std::string; + namespace ge { Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { GELOGD("AttachStreamLabelPass Enter."); @@ -187,21 +189,10 @@ Status AttachStreamLabelPass::UpdateEnterNode() { } std::stack enter_nodes; - std::string batch_label; for (const auto &enter_node : pair.second) { enter_nodes.emplace(enter_node); - std::string tmp_label; - (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); - if (!tmp_label.empty()) { - if (batch_label.empty()) { - batch_label = tmp_label; - } else if (batch_label != tmp_label) { - GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str()); - return FAILED; - } - } } - if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) { + if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { GELOGE(FAILED, "Update stream_label for loop_branch failed."); return FAILED; } @@ -226,10 +217,7 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no } for (const auto &enter_node : enter_nodes) { - GE_CHECK_NOTNULL(enter_node->GetOpDesc()); - if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) { - GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); - } + GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); } return SUCCESS; } @@ -241,8 +229,7 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no /// @param [in] batch_label /// @return Status /// -Status 
AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label, - const std::string &batch_label) { +Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, const string &stream_label) { std::stack nodes(enter_nodes); NodePtr cur_node = nullptr; while (!nodes.empty()) { @@ -251,11 +238,6 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_ for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { OpDescPtr out_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(out_desc); - std::string tmp_label; - (void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label); - if (!tmp_label.empty() && (tmp_label != batch_label)) { - continue; - } std::string out_type = out_desc->GetType(); bool need_skip = out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || diff --git a/ge/graph/passes/attach_stream_label_pass.h b/ge/graph/passes/attach_stream_label_pass.h index 19f11480..ad71d58f 100755 --- a/ge/graph/passes/attach_stream_label_pass.h +++ b/ge/graph/passes/attach_stream_label_pass.h @@ -58,11 +58,9 @@ class AttachStreamLabelPass : public GraphPass { /// @brief Update stream_label for loop_branch /// @param [in] enter_nodes /// @param [in] stream_label - /// @param [in] batch_label /// @return Status /// - static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label, - const std::string &batch_label); + static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); /// /// @brief Update stream_label start with enter nodes diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 68efbeb9..3b854c18 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder node->GetName().c_str(), node->GetType().c_str()); continue; } - if 
(node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); nodes_re_pass.insert(node_to_re_pass); } else { diff --git a/ge/graph/passes/dimension_adjust_pass.cc b/ge/graph/passes/dimension_adjust_pass.cc index fc5fe69f..5701faf5 100755 --- a/ge/graph/passes/dimension_adjust_pass.cc +++ b/ge/graph/passes/dimension_adjust_pass.cc @@ -80,7 +80,71 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { } } + ret = DealWithInNodes(node); + if (ret != SUCCESS) { + GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str()); + return ret; + } + std::vector data_relink_io_map = {kDataInputIndex}; return IsolateAndDeleteNode(node, data_relink_io_map); } + +Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + auto graph = node->GetOwnerComputeGraph(); + auto in_data_anchors = node->GetAllInDataAnchors(); + for (auto &in_data_anchor : in_data_anchors) { + if (in_data_anchor == nullptr) { + continue; + } + auto in_node_anchor = in_data_anchor->GetPeerOutAnchor(); + if (in_node_anchor == nullptr) { + continue; + } + auto in_node = in_node_anchor->GetOwnerNode(); + if (in_node->GetType() == SWITCHN) { + auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx()); + auto identity = + AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); + GE_CHECK_NOTNULL(identity); + GELOGI("Create new identity node[%s] after node %s[type: %s] success.", identity->GetName().c_str(), + in_node->GetName().c_str(), in_node->GetType().c_str()); + GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0))) + GE_CHECK_NOTNULL(identity->GetOutControlAnchor()); + if 
(identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) { + continue; + } + GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor())) + } + } + + return SUCCESS; +} + +NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor, + ComputeGraphPtr &graph) { + if (graph == nullptr) { + GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node."); + return nullptr; + } + + OpDescPtr desc = MakeShared("", ""); + if (desc == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to create op desc."); + return nullptr; + } + + desc->SetName(name); + desc->SetType(IDENTITY); + auto ret = desc->AddInputDesc(tensor); + auto ret2 = desc->AddOutputDesc(tensor); + if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { + GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity."); + return nullptr; + } + + return graph->AddNodeFront(desc); +} } // namespace ge diff --git a/ge/graph/passes/dimension_adjust_pass.h b/ge/graph/passes/dimension_adjust_pass.h index 685d9694..7766f140 100755 --- a/ge/graph/passes/dimension_adjust_pass.h +++ b/ge/graph/passes/dimension_adjust_pass.h @@ -34,6 +34,10 @@ namespace ge { class DimensionAdjustPass : public BaseNodePass { public: Status Run(ge::NodePtr &node) override; + + private: + Status DealWithInNodes(ge::NodePtr &node); + NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph); }; } // namespace ge diff --git a/ge/graph/passes/enter_pass.cc b/ge/graph/passes/enter_pass.cc index afeca78f..066c97cf 100644 --- a/ge/graph/passes/enter_pass.cc +++ b/ge/graph/passes/enter_pass.cc @@ -23,6 +23,7 @@ namespace { const size_t kOutNodesNum = 1; +const size_t kInCtrlNodesNum = 1; } namespace ge { @@ -55,6 +56,7 @@ Status EnterPass::Run(NodePtr &node) { if (out_ctrl_node == nullptr) { continue; } + GELOGI("Remove control edge from %s to %s.", node->GetName().c_str(), 
out_ctrl_node->GetName().c_str()); if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); @@ -62,8 +64,12 @@ Status EnterPass::Run(NodePtr &node) { } } } else { - if (OptimizeEnter(node, in_node) != SUCCESS) { - GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str()); + if (OptimizeEnterWithOnlyDataOut(node, in_node) != SUCCESS) { + GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str()); + return FAILED; + } + if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) { + GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str()); return FAILED; } } @@ -72,7 +78,7 @@ Status EnterPass::Run(NodePtr &node) { return SUCCESS; } -Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { +Status EnterPass::OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node) { if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { return SUCCESS; } @@ -83,17 +89,61 @@ Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { } GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); - GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); + GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))) const auto &out_data_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_data_anchor); for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); - GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); + GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)) + GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)) } - 
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)) AddNodeDeleted(node); AddRePassNodesWithInOut(in_node); return SUCCESS; } + +Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) { + auto out_ctrl_nodes = node->GetOutControlNodes(); + if (out_ctrl_nodes.empty()) { + return SUCCESS; + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor); + + for (auto &out_ctrl_node : out_ctrl_nodes) { + GE_CHECK_NOTNULL(out_ctrl_node); + if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) { + continue; + } + auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes(); + if (in_ctrl_nodes.size() != kInCtrlNodesNum) { + continue; + } + + // Skip when has merge out + bool has_merge_out = false; + auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes(); + for (const auto &out_node_of_const : out_nodes_of_const) { + GE_CHECK_NOTNULL(out_node_of_const); + if (out_node_of_const->GetType() == MERGE || out_node_of_const->GetType() == REFMERGE) { + has_merge_out = true; + break; + } + } + if (has_merge_out) { + continue; + } + + GELOGI("Unlink control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); + GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor())) + for (auto &out_node_of_const : out_nodes_of_const) { + if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) { + GELOGI("Link control edge from %s to %s.", node->GetName().c_str(), out_node_of_const->GetName().c_str()); + GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor())) + } + } + } + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/passes/enter_pass.h b/ge/graph/passes/enter_pass.h index 677516ff..1417b1f0 100644 --- a/ge/graph/passes/enter_pass.h +++ b/ge/graph/passes/enter_pass.h @@ -25,7 +25,8 @@ 
class EnterPass : public BaseNodePass { Status Run(NodePtr &node) override; private: - Status OptimizeEnter(NodePtr &node, NodePtr &in_node); + Status OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node); + Status UnlinkCtrlEdgeBeforeConst(NodePtr &node); }; } // namespace ge #endif // GE_GRAPH_PASSES_ENTER_PASS_H_ diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index 93dc2c40..227a0f61 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -173,10 +173,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { continue; } auto in_node = in_node_anchor->GetOwnerNode(); - if (in_node == nullptr) { - continue; - } - if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { + if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); auto ret = in_node_anchor->Unlink(in_data_anchor); if (ret != SUCCESS) { diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index 103fbb1b..c1a57a61 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -89,16 +89,6 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); } - if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { - string batch_label; - (void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); - if (!batch_label.empty()) { - auto stream_merge_desc = stream_merge->GetOpDesc(); - GE_CHECK_NOTNULL(stream_merge_desc); - (void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label); - } - } - return AddActiveNodes(graph, stream_merge); } diff --git a/ge/graph/passes/next_iteration_pass.cc 
b/ge/graph/passes/next_iteration_pass.cc index d8c4779d..cf46f09d 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -19,6 +19,8 @@ #include "common/ge/ge_util.h" #include "graph/common/omg_util.h" +using std::string; + namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { GELOGD("NextIterationPass Enter"); @@ -35,10 +37,6 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { return INTERNAL_ERROR; } } - if (GroupWithNoBatch(graph) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr."); - return INTERNAL_ERROR; - } if (FindWhileGroups() != SUCCESS) { GELOGE(INTERNAL_ERROR, "Find while groups failed."); @@ -73,75 +71,22 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { return FAILED; } - std::string batch_label; - (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label); - if (batch_label.empty()) { - auto frame_iter = frame_enter_map_.find(frame_name); - if (frame_iter == frame_enter_map_.end()) { - std::vector enter_nodes; - enter_nodes.emplace_back(enter_node); - frame_enter_map_[frame_name] = enter_nodes; - } else { - frame_iter->second.emplace_back(enter_node); - } - return SUCCESS; + string batch_label; + if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { + frame_name += batch_label; } - auto group_iter = loop_group_map_.find(frame_name); - if (group_iter == loop_group_map_.end()) { + auto iter = loop_group_map_.find(frame_name); + if (iter == loop_group_map_.end()) { LoopCondGroupPtr loop_group = MakeShared(); if (loop_group == nullptr) { GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); return FAILED; } loop_group->enter_nodes.emplace_back(enter_node); - loop_group_map_[frame_name][batch_label] = loop_group; + loop_group_map_[frame_name] = loop_group; } else { - auto batch_iter = group_iter->second.find(batch_label); - if (batch_iter == group_iter->second.end()) { - 
LoopCondGroupPtr loop_group = MakeShared(); - if (loop_group == nullptr) { - GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); - return FAILED; - } - loop_group->enter_nodes.emplace_back(enter_node); - group_iter->second[batch_label] = loop_group; - } else { - batch_iter->second->enter_nodes.emplace_back(enter_node); - } - } - - return SUCCESS; -} - -/// -/// @brief Group Enter nodes without batch_label attr -/// @param [in] compute_graph -/// @return Status -/// -Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { - if (frame_enter_map_.empty()) { - GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str()); - return SUCCESS; - } - for (const auto &item : frame_enter_map_) { - const std::string &frame_name = item.first; - auto iter = loop_group_map_.find(frame_name); - if (iter == loop_group_map_.end()) { - LoopCondGroupPtr loop_group = MakeShared(); - if (loop_group == nullptr) { - GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); - return FAILED; - } - loop_group->enter_nodes = item.second; - loop_group_map_[frame_name][""] = loop_group; - } else { - for (auto &batch_item : iter->second) { - for (const auto &enter_node : item.second) { - batch_item.second->enter_nodes.emplace_back(enter_node); - } - } - } + iter->second->enter_nodes.emplace_back(enter_node); } return SUCCESS; @@ -154,55 +99,39 @@ Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { Status NextIterationPass::FindWhileGroups() { for (const auto &loop_group_iter : loop_group_map_) { const std::string &frame_name = loop_group_iter.first; - for (const auto &batch_iter : loop_group_iter.second) { - const std::string &batch_label = batch_iter.first; - for (const auto &enter_node : batch_iter.second->enter_nodes) { - for (const auto &out_node : enter_node->GetOutAllNodes()) { - GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(), - frame_name.c_str(), 
batch_label.c_str()); - if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) { - continue; - } - std::string tmp_label; - GE_CHECK_NOTNULL(out_node->GetOpDesc()); - (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); - bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); - if (need_skip) { - continue; - } - - NodePtr next_node = nullptr; - if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, - "Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", - out_node->GetName().c_str()); - return INTERNAL_ERROR; - } - batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); - - NodePtr switch_node = nullptr; - if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", - out_node->GetName().c_str()); - return INTERNAL_ERROR; - } - if (switch_node == nullptr) { - continue; - } - - NodePtr loop_cond = nullptr; - if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { - GELOGE(INTERNAL_ERROR, - "Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", - switch_node->GetName().c_str()); - return INTERNAL_ERROR; - } - if (batch_iter.second->loop_cond == nullptr) { - batch_iter.second->loop_cond = loop_cond; - } else if (batch_iter.second->loop_cond != loop_cond) { - GELOGE(FAILED, "Multi LoopCond nodes exist."); - return FAILED; - } + for (const auto &enter_node : loop_group_iter.second->enter_nodes) { + for (const auto &out_node : enter_node->GetOutAllNodes()) { + const string &type = out_node->GetType(); + if ((type != MERGE) && (type != REFMERGE)) { + continue; + } + + NodePtr next_node = nullptr; + if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { + 
GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str()); + return INTERNAL_ERROR; + } + loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); + + NodePtr switch_node = nullptr; + if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); + return INTERNAL_ERROR; + } + if (switch_node == nullptr) { + continue; + } + + NodePtr loop_cond = nullptr; + if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); + return INTERNAL_ERROR; + } + if (loop_group_iter.second->loop_cond == nullptr) { + loop_group_iter.second->loop_cond = loop_cond; + } else if (loop_group_iter.second->loop_cond != loop_cond) { + GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); + return FAILED; } } } @@ -223,18 +152,16 @@ bool NextIterationPass::VerifyWhileGroup() { GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); return false; } - for (const auto &batch_iter : loop_group_iter.second) { - if (batch_iter.second->loop_cond == nullptr) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); - return false; - } + if (loop_group_iter.second->loop_cond == nullptr) { + GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); + return false; + } - for (const auto &pair_iter : batch_iter.second->merge_next_pairs) { - if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", - frame_name.c_str()); - return false; - } + for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { + if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) 
{ + GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", + frame_name.c_str()); + return false; } } } @@ -249,56 +176,53 @@ bool NextIterationPass::VerifyWhileGroup() { /// Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { for (const auto &loop_cond_iter : loop_group_map_) { - for (const auto &batch_iter : loop_cond_iter.second) { - const std::string &cond_name = batch_iter.second->loop_cond->GetName(); - GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); - - // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge - NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); - NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); - if ((enter_active == nullptr) || (next_active == nullptr)) { - GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); + const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); + GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); + + // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge + NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); + NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); + if ((enter_active == nullptr) || (next_active == nullptr)) { + GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); + return INTERNAL_ERROR; + } + + for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { + // Enter --> Active + if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(), + enter_active->GetName().c_str()); return INTERNAL_ERROR; } + } - for (const auto &enter_node : batch_iter.second->enter_nodes) { - // Enter --> Active - if 
(GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != - GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; - } + for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { + NodePtr merge_node = pair.first; + NodePtr next_node = pair.second; + // Active --> Merge + if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; } - for (const auto &pair : batch_iter.second->merge_next_pairs) { - NodePtr merge_node = pair.first; - NodePtr next_node = pair.second; - // Active --> Merge - if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != - GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; - } - - // NextIteration --> Active - if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; - } - - // break link between NextIteration and Merge - if (BreakNextIteration(next_node, merge_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); - return INTERNAL_ERROR; - } + // NextIteration --> Active + if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; } - if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || - (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); + // break link between NextIteration and Merge + if (BreakNextIteration(next_node, merge_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); return INTERNAL_ERROR; } } + + if ((SetActiveLabelList(enter_active, 
{cond_name}) != SUCCESS) || + (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { + GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); + return INTERNAL_ERROR; + } } return SUCCESS; @@ -365,12 +289,11 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & /// @param [in] node /// @param [in] target_type /// @param [in] is_input -/// @param [in] batch_label /// @param [out] target_node /// @return Status /// Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, - const std::string &batch_label, NodePtr &target_node) { + NodePtr &target_node) { if (node == nullptr) { GELOGE(PARAM_INVALID, "node is null."); return PARAM_INVALID; @@ -387,12 +310,6 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } for (const auto &tmp_node : nodes) { - std::string tmp_label; - (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); - bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); - if (need_skip) { - continue; - } const std::string type = tmp_node->GetType(); if ((target_type == LOOPCOND) && (type == target_type)) { target_node = tmp_node; @@ -415,7 +332,6 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string /// @return SUCCESS /// Status NextIterationPass::ClearStatus() { - frame_enter_map_.clear(); loop_group_map_.clear(); return SUCCESS; } diff --git a/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h index f8223c20..3266254d 100755 --- a/ge/graph/passes/next_iteration_pass.h +++ b/ge/graph/passes/next_iteration_pass.h @@ -46,13 +46,6 @@ class NextIterationPass : public GraphPass { /// Status GroupEnterNode(const NodePtr &enter_node); - /// - /// @brief Group Enter nodes without batch_label attr - /// @param [in] compute_graph - /// @return Status - /// - Status GroupWithNoBatch(const ComputeGraphPtr &graph); - /// /// 
@brief Find while groups /// @return Status @@ -97,13 +90,10 @@ class NextIterationPass : public GraphPass { /// @param [out] target_node /// @return Status /// - Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, - const std::string &batch_label, NodePtr &target_node); + Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); - // map> - std::unordered_map> frame_enter_map_; - // map> - std::unordered_map> loop_group_map_; + // map + std::unordered_map loop_group_map_; }; } // namespace ge #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ diff --git a/ge/graph/passes/remove_same_const_pass.cc b/ge/graph/passes/remove_same_const_pass.cc new file mode 100644 index 00000000..e75a4553 --- /dev/null +++ b/ge/graph/passes/remove_same_const_pass.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "remove_same_const_pass.h" + +#include +#include +#include + +#include "common/base64.h" +#include "ge_local_engine/engine/host_cpu_engine.h" +#include "graph/utils/node_utils.h" + +namespace ge { +namespace { +std::string GetCseKey(const NodePtr &node) { + std::stringstream ss; + ss << node->GetType() << "control-inputs-"; + std::set control_in_node_names; + for (auto &src_node : node->GetInControlNodes()) { + control_in_node_names.insert(src_node->GetName()); + } + for (auto &name : control_in_node_names) { + ss << name << "-"; + } + + ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc()); + + return ss.str(); +} + +bool IsConstType(const NodePtr &node) { return (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP); } +} // namespace +Status RemoveSameConstPass::Run(ComputeGraphPtr graph) { + GELOGD("Begin to run RemoveSameConstPass on the graph"); + GE_CHECK_NOTNULL(graph); + std::map keys_to_node; + for (const auto &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (!IsConstType(node)) { + continue; + } + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGI("Current node %s, type %s is unknown shape which should be skip.", + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + auto key = GetCseKey(node); + GELOGD("The const node %s cse key %s", node->GetName().c_str(), ge::base64::EncodeToBase64(key).c_str()); + auto iter = keys_to_node.find(key); + if (iter == keys_to_node.end()) { + keys_to_node[key] = node; + continue; + } + + if (node->GetAllOutDataAnchorsSize() != iter->second->GetAllOutDataAnchorsSize()) { + GELOGW("The const node %s and %s have the same CSE key, but different output anchor count, skip to fusion them", + iter->second->GetName().c_str(), 
+ node->GetName().c_str()); + continue; + } + + std::vector output_map(node->GetAllOutDataAnchorsSize()); + for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { + output_map[i] = i; + } + + ret = GraphUtils::ReplaceNodeAnchors(iter->second, node, {}, output_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to replace node %s by node %s, ret:%u", node->GetName().c_str(), + iter->second->GetName().c_str(), ret); + return INTERNAL_ERROR; + } + + NodeUtils::UnlinkAll(*node); + + ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to remove node %s from graph", node->GetName().c_str()); + return INTERNAL_ERROR; + } + + GELOGI("Remove const node %s by RemoveSameConstPass, replace it with node %s", node->GetName().c_str(), + iter->second->GetName().c_str()); + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/remove_same_const_pass.h b/ge/graph/passes/remove_same_const_pass.h new file mode 100644 index 00000000..08905bd2 --- /dev/null +++ b/ge/graph/passes/remove_same_const_pass.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ +#define GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ + +#include "graph/types.h" +#include "inc/graph_pass.h" + +namespace ge { +class RemoveSameConstPass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph) override ; +}; +} // namespace ge +#endif //GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_ diff --git a/ge/graph/passes/useless_control_out_remove_pass.cc b/ge/graph/passes/useless_control_out_remove_pass.cc new file mode 100644 index 00000000..4d74d582 --- /dev/null +++ b/ge/graph/passes/useless_control_out_remove_pass.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/useless_control_out_remove_pass.h" + +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" + +namespace ge { +Status UselessControlOutRemovePass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + + if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) { + return SUCCESS; + } + GELOGD("UselessControlOutRemovePass running, node: %s.", node->GetName().c_str()); + + // const has no control input + if (node->GetInControlNodes().empty()) { + if (node->GetOutDataNodes().empty()) { + // It is an isolated const, just remove it. 
+ GELOGI("Delete isolated const: %s.", node->GetName().c_str()); + GE_CHK_STATUS_RET(IsolateAndDeleteNode(node, {})) + AddNodeDeleted(node); + } else { + auto out_ctrl_anchor = node->GetOutControlAnchor(); + if (out_ctrl_anchor != nullptr && !out_ctrl_anchor->GetPeerAnchors().empty()) { + GELOGI("Node: %s unlink all out control edge.", node->GetName().c_str()); + out_ctrl_anchor->UnlinkAll(); + } + } + } + + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/ge/graph/passes/useless_control_out_remove_pass.h b/ge/graph/passes/useless_control_out_remove_pass.h new file mode 100644 index 00000000..d84b918f --- /dev/null +++ b/ge/graph/passes/useless_control_out_remove_pass.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_ +#define GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_ + +#include "graph/passes/base_pass.h" + +namespace ge { +class UselessControlOutRemovePass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_ \ No newline at end of file diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index a90f145e..c8880b2e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -44,6 +44,8 @@ using std::set; using std::string; using std::vector; +using std::map; +using std::queue; namespace ge { namespace multibatch { @@ -57,10 +59,15 @@ const int kDataInIndex = 0; const int kMergeDataOutIndex = 0; const int kStaticOutput = -1; const int kDivisionConst = 2; +const int32_t kOneInDataNode = 1; +const int32_t kFindNoMatch = 0; inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } +inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); } +const set unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER}); + inline bool IsGetNextType(const NodePtr &node) { std::string original_type; GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, @@ -218,12 +225,6 @@ Status MultiBatchGraphCopyer::CopyGraph() { return ret; } - ret = InsertIdentityAfterSwitchN(); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node."); - return INTERNAL_ERROR; - } - GELOGI("Begin to remove useless nodes by prune pass after copy process"); PrunePass prune_pass; ret = prune_pass.Run(graph_); @@ -240,6 +241,18 @@ Status MultiBatchGraphCopyer::Init() { return ret; } + ret = RelinkConstCtrlEdge(); + if (ret != SUCCESS) { + GELOGE(FAILED, "Relink const's control edge 
failed."); + return FAILED; + } + + ret = ExtractUnchangedStructureOutofCycle(); + if (ret != SUCCESS) { + GELOGE(FAILED, "Extract unchanged structure out of cycle failed."); + return FAILED; + } + for (auto &node : graph_->GetAllNodes()) { origin_all_nodes_.emplace_back(node); if (IsDataLikeType(node->GetType())) { @@ -252,6 +265,281 @@ Status MultiBatchGraphCopyer::Init() { return SUCCESS; } +Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() { + for (auto &node : graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { + if (node->GetOutDataNodes().empty()) { + continue; + } + if (!node->GetInControlNodes().empty()) { + auto in_ctrl_nodes = node->GetInControlNodes(); + auto out_nodes = node->GetOutAllNodes(); + bool has_merge_out = false; + for (const auto &out_node : out_nodes) { + GE_CHECK_NOTNULL(out_node); + if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) { + has_merge_out = true; + break; + } + } + if (has_merge_out) { + continue; + } + auto in_ctrl_anchor = node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + in_ctrl_anchor->UnlinkAll(); + for (auto &in_ctrl_node : in_ctrl_nodes) { + auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node); + for (auto &out_node : out_nodes) { + if (IsEnterType(out_node->GetType())) { + continue; + } + if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) { + GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor())) + } + } + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + if (out_ctrl_anchor != nullptr) { + out_ctrl_anchor->UnlinkAll(); + } + } + } + + return SUCCESS; +} + +Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() { + map> frame_enter; + if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) { + GELOGE(FAILED, "Get enter nodes grouped by frame_name 
failed."); + return FAILED; + } + + queue nodes_to_extract; + if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) { + GELOGE(FAILED, "Get nodes needed to extract failed."); + return FAILED; + } + + while (!nodes_to_extract.empty()) { + auto node = nodes_to_extract.front(); + nodes_to_extract.pop(); + OpDescPtr enter_desc = nullptr; + if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) { + GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str()); + return FAILED; + } + set out_nodes; + if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) { + GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str()); + return FAILED; + } + + if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) { + GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str()); + return FAILED; + } + + for (auto &out_node : out_nodes) { + GE_CHECK_NOTNULL(out_node); + if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) { + nodes_to_extract.push(out_node); + } + } + } + + if (DeleteEnterWithoutDataOut() != SUCCESS) { + GELOGE(FAILED, "Delete enter node without out data nodes failed."); + return FAILED; + } + + return SUCCESS; +} + +Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map> &frame_enter) { + for (auto &node : graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + if (IsEnterType(node->GetType())) { + if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) { + continue; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + string frame_name; + if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { + GELOGE(FAILED, "Get attr frame_name of enter[%s] failed.", node->GetName().c_str()); + return FAILED; + } + frame_enter[frame_name].emplace_back(node); + } + } + + return SUCCESS; +} + +Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map> &frame_enter, + queue &nodes_to_extract) { + for (const auto &one_group
: frame_enter) { + auto enters = one_group.second; + for (const auto &enter : enters) { + auto out_data_nodes = enter->GetOutDataNodes(); + for (const auto &out_data_node : out_data_nodes) { + GE_CHECK_NOTNULL(out_data_node); + if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) { + nodes_to_extract.push(out_data_node); + } + } + } + } + + return SUCCESS; +} + +bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) { + auto out_data_nodes = node->GetOutDataNodes(); + for (const auto &out_data_node : out_data_nodes) { + if (out_data_node == nullptr) { + return false; + } + + if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) { + return false; + } + } + + auto in_data_nodes = node->GetInDataNodes(); + if (in_data_nodes.size() == kOneInDataNode) { + return true; + } + + for (const auto &in_data_node : in_data_nodes) { + if (in_data_node == nullptr) { + return false; + } + if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) { + return false; + } + } + + return true; +} + +Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) { + auto in_data_anchors = node->GetAllInDataAnchors(); + for (auto &in_data_anchor : in_data_anchors) { + auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_data_anchor); + auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode(); + if (IsEnterType(peer_in_data_node->GetType())) { + GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor)) + GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str()); + auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors(); + for (auto &enter_in_data_anchor : enter_in_data_anchors) { + auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter); + if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) { + 
continue; + } + GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor)) + GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(), + node->GetName().c_str()); + } + enter_desc = peer_in_data_node->GetOpDesc(); + GE_CHECK_NOTNULL(enter_desc); + } + } + + return SUCCESS; +} + +Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr &copy_desc, set &out_nodes) { + if (copy_desc == nullptr) { + return SUCCESS; + } + map>> outanchors_inanchors_nodes; + auto out_data_anchors = node->GetAllOutDataAnchors(); + for (auto &out_data_anchor : out_data_anchors) { + auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors(); + for (auto peer_in_data_anchor : peer_in_data_anchors) { + GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); + out_nodes.emplace(peer_in_data_node); + outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node)); + } + } + + int32_t i = 0; + auto node_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(node_desc); + // Insert one enter node after node's per out data anchor + for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) { + string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++); + GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str()); + auto enter_desc = AttrUtils::CopyOpDesc(copy_desc); + enter_desc->SetName(name); + GE_CHK_STATUS_RET( + enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) + GE_CHK_STATUS_RET( + enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) + auto enter_node = graph_->AddNode(enter_desc); + GE_CHECK_NOTNULL(enter_node); + GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex))) +
GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex)); + for (auto &inanchor_node : outanchor_inanchors_nodes.second) { + GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first)) + GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first)) + GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(), + inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(), + inanchor_node.second->GetName().c_str()); + } + } + + return SUCCESS; +} + +// Move node's in control edges to out data nodes +Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set &out_nodes) { + auto in_ctrl_anchor = node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors(); + for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) { + GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor)) + GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), + node->GetName().c_str()); + for (auto &out_node : out_nodes) { + auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node); + if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) { + GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node)) + GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), + out_node->GetName().c_str()); + } + } + } + + return SUCCESS; +} + +Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() { + for (auto &node : graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + if (IsEnterType(node->GetType())) { + auto out_nodes = node->GetOutAllNodes(); + if (out_nodes.empty()) { + GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str()); + GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {})) + 
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node)) + } + } + } + + return SUCCESS; +} + void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), @@ -297,6 +585,9 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { LabelStatusForGetNextSink(data); } } + + map> frame_enters; + InitStatus(frame_enters); bool changed = true; // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch while (changed) { @@ -306,12 +597,13 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { if (iter != origin_nodes_status_.end()) { continue; } - for (auto &in_node : node->GetInAllNodes()) { - bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() && - origin_nodes_status_[in_node.get()] == kNodeInBatchBranch; - if (is_in_batch) { - origin_nodes_status_[node.get()] = kNodeInBatchBranch; - changed = true; + for (auto &in_node : node->GetInDataNodes()) { + if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) { + if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { + origin_nodes_status_[node.get()] = kNodeInBatchBranch; + ResetEnterStatus(frame_enters, node); + changed = true; + } break; } } @@ -320,6 +612,45 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { return SUCCESS; } +void MultiBatchGraphCopyer::InitStatus(map> &frame_enters) { + for (const auto &node : origin_all_nodes_) { + if (!IsEnterType(node->GetType())) { + continue; + } + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + string frame_name; + if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { + frame_enters[frame_name].emplace_back(node); + } + } + + for (const auto &data : origin_data_nodes_) { + auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
+ if (!IsAllDimsPositive(data_shape.GetDims())) { + origin_nodes_status_[data.get()] = kNodeInBatchBranch; + } + } +} + +void MultiBatchGraphCopyer::ResetEnterStatus(map> &frame_enters, const NodePtr &node) { + if (!IsEnterType(node->GetType())) { + return; + } + + for (const auto &frame_enter : frame_enters) { + auto &enters = frame_enter.second; + if (std::find(enters.begin(), enters.end(), node) != enters.end()) { + for (const auto &enter : enters) { + origin_nodes_status_[enter.get()] = kNodeInBatchBranch; + } + break; + } + } +} + Status MultiBatchGraphCopyer::LabelStatus() { if (LabelInBatchBranchStatus() != SUCCESS) { GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); @@ -1360,52 +1691,6 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { return SUCCESS; } -Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { - for (auto &node : graph_->GetAllNodes()) { - if (node->GetType() != SWITCHN) { - continue; - } - auto switchn_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_desc); - size_t i = 0; - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - auto out_node = in_data_anchor->GetOwnerNode(); - auto op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { - GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str()); - continue; - } - - auto identity_desc = MakeShared(node->GetName() + "_identity_" + std::to_string(i), IDENTITY); - GE_CHECK_NOTNULL(identity_desc); - - string batch_label; - if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str()); - return FAILED; - } - } - - auto data_desc = 
switchn_desc->GetOutputDesc(i); - i++; - GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc)); - GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc)); - - auto identity_node = graph_->AddNode(identity_desc); - GE_CHECK_NOTNULL(identity_node); - GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0))); - GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor()); - GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor())); - } - } - } - - return SUCCESS; -} - Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (GetLocalOmgContext().dynamic_node_type.empty()) { const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); diff --git a/ge/graph/preprocess/multi_batch_copy_graph.h b/ge/graph/preprocess/multi_batch_copy_graph.h index a0de4413..d51c4c02 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.h +++ b/ge/graph/preprocess/multi_batch_copy_graph.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "external/ge/ge_api_error_codes.h" @@ -64,12 +65,26 @@ class MultiBatchGraphCopyer { private: Status Init(); Status CheckArguments(); + Status RelinkConstCtrlEdge(); + + Status ExtractUnchangedStructureOutofCycle(); + Status GetEnterNodesGroupByFrame(std::map> &frame_enter); + Status GetNodeNeedExtract(const std::map> &frame_enter, + std::queue &nodes_to_extract); + bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node); + Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc); + Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set &out_nodes); + Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set &out_nodes); + Status DeleteEnterWithoutDataOut(); // label status for origin_all_nodes_ Status LabelStatus(); Status LabelInBatchBranchStatus(); void LabelStatusForData(const NodePtr &data); void LabelStatusForGetNextSink(const NodePtr &data); + void InitStatus(std::map> &frame_enters); + void 
ResetEnterStatus(std::map> &frame_enters, const NodePtr &node); + // add nodes functions Status CreateNewNodes(); @@ -81,7 +96,6 @@ class MultiBatchGraphCopyer { Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, std::vector> &dynamic_out_to_switchn); - Status InsertIdentityAfterSwitchN(); Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 6fad46bf..8eec3df6 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -245,6 +245,8 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" "${GE_CODE_DIR}/ge/model/ge_model.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" "${GE_CODE_DIR}/ge/graph/load/new_model_manager/model_utils.cc" @@ -475,6 +477,8 @@ set(GRAPH_PASS_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc"