From f23a4de0e0e444d486884b0efba55cfbb2fbb95d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Thu, 10 Dec 2020 16:55:32 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!536=20:?= =?UTF-8?q?=20Decrease=20transformer's=20om=20size=20in=20dynamic=20dims?= =?UTF-8?q?=20scenario'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/CMakeLists.txt | 1 - ge/ge_inference.mk | 1 - ge/ge_runner.mk | 1 - ge/graph/build/model_builder.cc | 1 - ge/graph/manager/graph_manager.cc | 5 - ge/graph/passes/attach_stream_label_pass.cc | 28 +- ge/graph/passes/attach_stream_label_pass.h | 4 +- ge/graph/passes/base_pass.cc | 2 +- .../common_subexpression_elimination_pass.cc | 3 +- ge/graph/passes/const_pass.cc | 55 --- ge/graph/passes/const_pass.h | 29 -- ge/graph/passes/dimension_adjust_pass.cc | 64 --- ge/graph/passes/dimension_adjust_pass.h | 4 - ge/graph/passes/enter_pass.cc | 48 +-- ge/graph/passes/enter_pass.h | 3 +- ge/graph/passes/folding_pass.cc | 5 +- ge/graph/passes/merge_to_stream_merge_pass.cc | 10 + ge/graph/passes/next_iteration_pass.cc | 262 ++++++++---- ge/graph/passes/next_iteration_pass.h | 16 +- ge/graph/preprocess/multi_batch_copy_graph.cc | 401 +++--------------- ge/graph/preprocess/multi_batch_copy_graph.h | 16 +- 21 files changed, 295 insertions(+), 664 deletions(-) delete mode 100644 ge/graph/passes/const_pass.cc delete mode 100644 ge/graph/passes/const_pass.h diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index b037f4a4..88a5c52f 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -154,7 +154,6 @@ set(TRAIN_SRC_LIST "graph/passes/compile_nodes_pass.cc" "graph/passes/constant_folding_pass.cc" "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/control_trigger_pass.cc" "graph/passes/dimension_adjust_pass.cc" "graph/passes/dimension_compute_pass.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index fe76a612..0987f148 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -189,7 +189,6 @@ OMG_HOST_SRC_FILES := \ graph/passes/control_trigger_pass.cc \ graph/passes/cond_pass.cc \ graph/passes/cond_remove_pass.cc \ - graph/passes/const_pass.cc \ graph/passes/for_pass.cc \ graph/passes/enter_pass.cc \ graph/passes/assign_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 58ad1266..a2679ed1 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -123,7 +123,6 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/compile_nodes_pass.cc \ graph/passes/constant_folding_pass.cc \ graph/passes/constant_fuse_same_pass.cc \ - graph/passes/const_pass.cc \ graph/passes/control_trigger_pass.cc \ graph/passes/dimension_adjust_pass.cc \ graph/passes/dimension_compute_pass.cc \ diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 3be45895..37eb499a 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_ GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); size_t output_size = weight->GetData().size(); TensorUtils::SetDataOffset(tensor_desc, mem_offset); - GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size); mem_offset += output_size; } return SUCCESS; diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index d4c6ca8d..9ce68d76 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -56,7 +56,6 @@ #include "graph/passes/cond_remove_pass.h" #include "graph/passes/constant_folding_pass.h" #include "graph/passes/constant_fuse_same_pass.h" -#include "graph/passes/const_pass.cc" #include "graph/passes/control_trigger_pass.h" #include "graph/passes/ctrl_edge_transfer_pass.h" #include "graph/passes/dimension_adjust_pass.h" @@ -2138,7 +2137,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { TransposeTransDataPass transpose_transdata_pass; TransOpSymmetryEliminationPass symmetry_elimination_pass; DimensionComputePass dimension_compute_pass; - ConstPass const_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); names_to_passes.emplace_back("AddNPass", &addn_pass); names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); @@ -2152,7 +2150,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); - names_to_passes.emplace_back("ConstPass", &const_pass); GE_TIMESTAMP_START(names_to_passes); ret = GEPass(compute_graph).Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); @@ -2193,8 +2190,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", new (std::nothrow) VariableRefUselessControlOutDeletePass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) - GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::CommonSubexpressionEliminationPass", - new (std::nothrow) CommonSubexpressionEliminationPass)); if (options_.train_graph_flag) { // Priority: The GlobalStepInsertPass should work before graph partitioner. // Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index cd3509c7..c0e0f669 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -18,8 +18,6 @@ #include "ge/ge_api_types.h" #include "graph/common/omg_util.h" -using std::string; - namespace ge { Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { GELOGD("AttachStreamLabelPass Enter."); @@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() { } std::stack enter_nodes; + std::string batch_label; for (const auto &enter_node : pair.second) { enter_nodes.emplace(enter_node); + std::string tmp_label; + (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); + if (!tmp_label.empty()) { + if (batch_label.empty()) { + batch_label = tmp_label; + } else if (batch_label != tmp_label) { + GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str()); + return FAILED; + } + } } - if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { + if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) { GELOGE(FAILED, "Update stream_label for loop_branch failed."); return FAILED; } @@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no } for (const auto &enter_node : enter_nodes) { - GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); + GE_CHECK_NOTNULL(enter_node->GetOpDesc()); + if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) { + GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); + } } return SUCCESS; } @@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no /// @param [in] batch_label /// @return Status /// -Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, const string &stream_label) { +Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label, + const std::string &batch_label) { std::stack nodes(enter_nodes); NodePtr cur_node = nullptr; while (!nodes.empty()) { @@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_ for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { OpDescPtr out_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(out_desc); + std::string tmp_label; + (void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label); + if (!tmp_label.empty() && (tmp_label != batch_label)) { + continue; + } std::string out_type = out_desc->GetType(); bool need_skip = out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || diff --git a/ge/graph/passes/attach_stream_label_pass.h b/ge/graph/passes/attach_stream_label_pass.h index ad71d58f..19f11480 100755 --- a/ge/graph/passes/attach_stream_label_pass.h +++ b/ge/graph/passes/attach_stream_label_pass.h @@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass { /// @brief Update stream_label for loop_branch /// @param [in] enter_nodes /// @param [in] stream_label + /// @param [in] batch_label /// @return Status /// - static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); + static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label, + const std::string &batch_label); /// /// @brief Update stream_label start with enter nodes diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 8d0bcf25..68efbeb9 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder node->GetName().c_str(), node->GetType().c_str()); continue; } - if (node_to_re_pass->IsAllInNodesSeen(nodes_seen) || node_to_re_pass->GetType() == ENTER) { + if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); nodes_re_pass.insert(node_to_re_pass); } else { diff --git a/ge/graph/passes/common_subexpression_elimination_pass.cc b/ge/graph/passes/common_subexpression_elimination_pass.cc index 9e771b65..a4662d5d 100644 --- a/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -58,8 +58,7 @@ std::string GetCseKey(const NodePtr &node) { /// To avoid delete wrong nodes(e.g. stateful nodes), /// only nodes have folding kernel will be considered for the CSE process bool IsNodeSupportCse(const NodePtr &node) { - if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node)) || node->GetType() == CONSTANT || - node->GetType() == CONSTANTOP) { + if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node))) { return true; } return folding_pass::GetKernelByType(node) != nullptr; diff --git a/ge/graph/passes/const_pass.cc b/ge/graph/passes/const_pass.cc deleted file mode 100644 index 42b3c23f..00000000 --- a/ge/graph/passes/const_pass.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/const_pass.h" - -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" - -namespace ge { -Status ConstPass::Run(NodePtr &node) { - GE_CHECK_NOTNULL(node); - - if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) { - return SUCCESS; - } - GELOGD("ConstPass running, node: %s.", node->GetName().c_str()); - - // const has no control input - if (node->GetInControlNodes().empty()) { - auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - GELOGD("Node: %s unlink all out control edge.", node->GetName().c_str()); - out_ctrl_anchor->UnlinkAll(); - } - - if (node->GetOutAllNodes().empty()) { - // it is an isolated const, just remove it. - GELOGD("Delete isolated const: %s.", node->GetName().c_str()); - auto graph = node->GetOwnerComputeGraph(); - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove const %s failed.", node->GetName().c_str()); - return FAILED; - } - AddNodeDeleted(node); - } - } - - return SUCCESS; -} -} // namespace ge \ No newline at end of file diff --git a/ge/graph/passes/const_pass.h b/ge/graph/passes/const_pass.h deleted file mode 100644 index a7e011ec..00000000 --- a/ge/graph/passes/const_pass.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_CONST_PASS_H_ -#define GE_GRAPH_PASSES_CONST_PASS_H_ - -#include "graph/passes/base_pass.h" - -namespace ge { -class ConstPass : public BaseNodePass { - public: - Status Run(NodePtr &node) override; -}; -} // namespace ge - -#endif // GE_GRAPH_PASSES_CONST_PASS_H_ \ No newline at end of file diff --git a/ge/graph/passes/dimension_adjust_pass.cc b/ge/graph/passes/dimension_adjust_pass.cc index bfb9cb4f..fc5fe69f 100755 --- a/ge/graph/passes/dimension_adjust_pass.cc +++ b/ge/graph/passes/dimension_adjust_pass.cc @@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { } } - ret = DealWithInNodes(node); - if (ret != SUCCESS) { - GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str()); - return ret; - } - std::vector data_relink_io_map = {kDataInputIndex}; return IsolateAndDeleteNode(node, data_relink_io_map); } - -Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - auto graph = node->GetOwnerComputeGraph(); - auto in_data_anchors = node->GetAllInDataAnchors(); - for (auto &in_data_anchor : in_data_anchors) { - if (in_data_anchor == nullptr) { - continue; - } - auto in_node_anchor = in_data_anchor->GetPeerOutAnchor(); - if (in_node_anchor == nullptr) { - continue; - } - auto in_node = in_node_anchor->GetOwnerNode(); - if (in_node->GetType() == SWITCHN) { - GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); - auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx()); - auto identity = - AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); - GE_CHECK_NOTNULL(identity); - GELOGI("Create new identity node[%s] success.", identity->GetName().c_str()); - GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0))) - GE_CHECK_NOTNULL(identity->GetOutControlAnchor()); - if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) { - continue; - } - GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor())) - } - } - - return SUCCESS; -} - -NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor, - ComputeGraphPtr &graph) { - if (graph == nullptr) { - GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node."); - return nullptr; - } - - OpDescPtr desc = MakeShared("", ""); - if (desc == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to create op desc."); - return nullptr; - } - - desc->SetName(name); - desc->SetType(IDENTITY); - auto ret = desc->AddInputDesc(tensor); - auto ret2 = desc->AddOutputDesc(tensor); - if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity."); - return nullptr; - } - - return graph->AddNodeFront(desc); -} } // namespace ge diff --git a/ge/graph/passes/dimension_adjust_pass.h b/ge/graph/passes/dimension_adjust_pass.h index 7766f140..685d9694 100755 --- a/ge/graph/passes/dimension_adjust_pass.h +++ b/ge/graph/passes/dimension_adjust_pass.h @@ -34,10 +34,6 @@ namespace ge { class DimensionAdjustPass : public BaseNodePass { public: Status Run(ge::NodePtr &node) override; - - private: - Status DealWithInNodes(ge::NodePtr &node); - NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph); }; } // namespace ge diff --git a/ge/graph/passes/enter_pass.cc b/ge/graph/passes/enter_pass.cc index 20e60403..afeca78f 100644 --- a/ge/graph/passes/enter_pass.cc +++ b/ge/graph/passes/enter_pass.cc @@ -23,7 +23,6 @@ namespace { const size_t kOutNodesNum = 1; -const size_t kInCtrlNodesNum = 1; } namespace ge { @@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) { if (out_ctrl_node == nullptr) { continue; } - GELOGD("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); @@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) { } } } else { - if (OptimizeEnterWithOnlyOutData(node, in_node) != SUCCESS) { - GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str()); - return FAILED; - } - if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) { - GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str()); + if (OptimizeEnter(node, in_node) != SUCCESS) { + GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str()); return FAILED; } } @@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) { return SUCCESS; } -Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) { +Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { return SUCCESS; } @@ -89,45 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) } GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); - GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))) + GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); const auto &out_data_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_data_anchor); for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)) - GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)) + GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); + GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); } - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)) + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); AddNodeDeleted(node); AddRePassNodesWithInOut(in_node); return SUCCESS; } - -Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) { - auto out_ctrl_nodes = node->GetOutControlNodes(); - if (out_ctrl_nodes.empty()) { - return SUCCESS; - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_ctrl_anchor); - - for (auto &out_ctrl_node : out_ctrl_nodes) { - GE_CHECK_NOTNULL(out_ctrl_node); - if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) { - continue; - } - auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes(); - if (in_ctrl_nodes.size() != kInCtrlNodesNum) { - continue; - } - GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor())) - auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes(); - for (auto &out_node_of_const : out_nodes_of_const) { - if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) { - GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor())) - } - } - } - return SUCCESS; -} } // namespace ge diff --git a/ge/graph/passes/enter_pass.h b/ge/graph/passes/enter_pass.h index 67366297..677516ff 100644 --- a/ge/graph/passes/enter_pass.h +++ b/ge/graph/passes/enter_pass.h @@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass { Status Run(NodePtr &node) override; private: - Status OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node); - Status UnlinkCtrlEdgeBeforeConst(NodePtr &node); + Status OptimizeEnter(NodePtr &node, NodePtr &in_node); }; } // namespace ge #endif // GE_GRAPH_PASSES_ENTER_PASS_H_ diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc index 227a0f61..93dc2c40 100755 --- a/ge/graph/passes/folding_pass.cc +++ b/ge/graph/passes/folding_pass.cc @@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { continue; } auto in_node = in_node_anchor->GetOwnerNode(); - if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { + if (in_node == nullptr) { + continue; + } + if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); auto ret = in_node_anchor->Unlink(in_data_anchor); if (ret != SUCCESS) { diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index c1a57a61..103fbb1b 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); } + if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { + string batch_label; + (void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (!batch_label.empty()) { + auto stream_merge_desc = stream_merge->GetOpDesc(); + GE_CHECK_NOTNULL(stream_merge_desc); + (void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label); + } + } + return AddActiveNodes(graph, stream_merge); } diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index cf46f09d..d8c4779d 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -19,8 +19,6 @@ #include "common/ge/ge_util.h" #include "graph/common/omg_util.h" -using std::string; - namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { GELOGD("NextIterationPass Enter"); @@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { return INTERNAL_ERROR; } } + if (GroupWithNoBatch(graph) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr."); + return INTERNAL_ERROR; + } if (FindWhileGroups() != SUCCESS) { GELOGE(INTERNAL_ERROR, "Find while groups failed."); @@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { return FAILED; } - string batch_label; - if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - frame_name += batch_label; + std::string batch_label; + (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (batch_label.empty()) { + auto frame_iter = frame_enter_map_.find(frame_name); + if (frame_iter == frame_enter_map_.end()) { + std::vector enter_nodes; + enter_nodes.emplace_back(enter_node); + frame_enter_map_[frame_name] = enter_nodes; + } else { + frame_iter->second.emplace_back(enter_node); + } + return SUCCESS; } - auto iter = loop_group_map_.find(frame_name); - if (iter == loop_group_map_.end()) { + auto group_iter = loop_group_map_.find(frame_name); + if (group_iter == loop_group_map_.end()) { LoopCondGroupPtr loop_group = MakeShared(); if (loop_group == nullptr) { GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); return FAILED; } loop_group->enter_nodes.emplace_back(enter_node); - loop_group_map_[frame_name] = loop_group; + loop_group_map_[frame_name][batch_label] = loop_group; } else { - iter->second->enter_nodes.emplace_back(enter_node); + auto batch_iter = group_iter->second.find(batch_label); + if (batch_iter == group_iter->second.end()) { + LoopCondGroupPtr loop_group = MakeShared(); + if (loop_group == nullptr) { + GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); + return FAILED; + } + loop_group->enter_nodes.emplace_back(enter_node); + group_iter->second[batch_label] = loop_group; + } else { + batch_iter->second->enter_nodes.emplace_back(enter_node); + } + } + + return SUCCESS; +} + +/// +/// @brief Group Enter nodes without batch_label attr +/// @param [in] compute_graph +/// @return Status +/// +Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { + if (frame_enter_map_.empty()) { + GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str()); + return SUCCESS; + } + for (const auto &item : frame_enter_map_) { + const std::string &frame_name = item.first; + auto iter = loop_group_map_.find(frame_name); + if (iter == loop_group_map_.end()) { + LoopCondGroupPtr loop_group = MakeShared(); + if (loop_group == nullptr) { + GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); + return FAILED; + } + loop_group->enter_nodes = item.second; + loop_group_map_[frame_name][""] = loop_group; + } else { + for (auto &batch_item : iter->second) { + for (const auto &enter_node : item.second) { + batch_item.second->enter_nodes.emplace_back(enter_node); + } + } + } } return SUCCESS; @@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { Status NextIterationPass::FindWhileGroups() { for (const auto &loop_group_iter : loop_group_map_) { const std::string &frame_name = loop_group_iter.first; - for (const auto &enter_node : loop_group_iter.second->enter_nodes) { - for (const auto &out_node : enter_node->GetOutAllNodes()) { - const string &type = out_node->GetType(); - if ((type != MERGE) && (type != REFMERGE)) { - continue; - } - - NodePtr next_node = nullptr; - if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str()); - return INTERNAL_ERROR; - } - loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); - - NodePtr switch_node = nullptr; - if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); - return INTERNAL_ERROR; - } - if (switch_node == nullptr) { - continue; - } - - NodePtr loop_cond = nullptr; - if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); - return INTERNAL_ERROR; - } - if (loop_group_iter.second->loop_cond == nullptr) { - loop_group_iter.second->loop_cond = loop_cond; - } else if (loop_group_iter.second->loop_cond != loop_cond) { - GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); - return FAILED; + for (const auto &batch_iter : loop_group_iter.second) { + const std::string &batch_label = batch_iter.first; + for (const auto &enter_node : batch_iter.second->enter_nodes) { + for (const auto &out_node : enter_node->GetOutAllNodes()) { + GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(), + frame_name.c_str(), batch_label.c_str()); + if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) { + continue; + } + std::string tmp_label; + GE_CHECK_NOTNULL(out_node->GetOpDesc()); + (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); + bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); + if (need_skip) { + continue; + } + + NodePtr next_node = nullptr; + if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, + "Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", + out_node->GetName().c_str()); + return INTERNAL_ERROR; + } + batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); + + NodePtr switch_node = nullptr; + if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", + out_node->GetName().c_str()); + return INTERNAL_ERROR; + } + if (switch_node == nullptr) { + continue; + } + + NodePtr loop_cond = nullptr; + if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { + GELOGE(INTERNAL_ERROR, + "Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", + switch_node->GetName().c_str()); + return INTERNAL_ERROR; + } + if (batch_iter.second->loop_cond == nullptr) { + batch_iter.second->loop_cond = loop_cond; + } else if (batch_iter.second->loop_cond != loop_cond) { + GELOGE(FAILED, "Multi LoopCond nodes exist."); + return FAILED; + } } } } @@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() { GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); return false; } - if (loop_group_iter.second->loop_cond == nullptr) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); - return false; - } - - for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { - if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", - frame_name.c_str()); + for (const auto &batch_iter : loop_group_iter.second) { + if (batch_iter.second->loop_cond == nullptr) { + GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); return false; } + + for (const auto &pair_iter : batch_iter.second->merge_next_pairs) { + if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { + GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", + frame_name.c_str()); + return false; + } + } } } @@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() { /// Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { for (const auto &loop_cond_iter : loop_group_map_) { - const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); - GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); - - // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge - NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); - NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); - if ((enter_active == nullptr) || (next_active == nullptr)) { - GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); - return INTERNAL_ERROR; - } - - for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { - // Enter --> Active - if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(), - enter_active->GetName().c_str()); + for (const auto &batch_iter : loop_cond_iter.second) { + const std::string &cond_name = batch_iter.second->loop_cond->GetName(); + GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); + + // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge + NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); + NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); + if ((enter_active == nullptr) || (next_active == nullptr)) { + GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); return INTERNAL_ERROR; } - } - for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { - NodePtr merge_node = pair.first; - NodePtr next_node = pair.second; - // Active --> Merge - if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; + for (const auto &enter_node : batch_iter.second->enter_nodes) { + // Enter --> Active + if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != + GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; + } } - // NextIteration --> Active - if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); - return INTERNAL_ERROR; + for (const auto &pair : batch_iter.second->merge_next_pairs) { + NodePtr merge_node = pair.first; + NodePtr next_node = pair.second; + // Active --> Merge + if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != + GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; + } + + // NextIteration --> Active + if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edge failed."); + return INTERNAL_ERROR; + } + + // break link between NextIteration and Merge + if (BreakNextIteration(next_node, merge_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); + return INTERNAL_ERROR; + } } - // break link between NextIteration and Merge - if (BreakNextIteration(next_node, merge_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); + if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || + (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { + GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); return INTERNAL_ERROR; } } - - if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || - (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); - return INTERNAL_ERROR; - } } return SUCCESS; @@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & /// @param [in] node /// @param [in] target_type /// @param [in] is_input +/// @param [in] batch_label /// @param [out] target_node /// @return Status /// Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, - NodePtr &target_node) { + const std::string &batch_label, NodePtr &target_node) { if (node == nullptr) { GELOGE(PARAM_INVALID, "node is null."); return PARAM_INVALID; @@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } for (const auto &tmp_node : nodes) { + std::string tmp_label; + (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); + bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); + if (need_skip) { + continue; + } const std::string type = tmp_node->GetType(); if ((target_type == LOOPCOND) && (type == target_type)) { target_node = tmp_node; @@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string /// @return SUCCESS /// Status NextIterationPass::ClearStatus() { + frame_enter_map_.clear(); loop_group_map_.clear(); return SUCCESS; } diff --git a/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h index 3266254d..f8223c20 100755 --- a/ge/graph/passes/next_iteration_pass.h +++ b/ge/graph/passes/next_iteration_pass.h @@ -46,6 +46,13 @@ class NextIterationPass : public GraphPass { /// Status GroupEnterNode(const NodePtr &enter_node); + /// + /// @brief Group Enter nodes without batch_label attr + /// @param [in] compute_graph + /// @return Status + /// + Status GroupWithNoBatch(const ComputeGraphPtr &graph); + /// /// @brief Find while groups /// @return Status @@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass { /// @param [out] target_node /// @return Status /// - Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); + Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, + const std::string &batch_label, NodePtr &target_node); - // map - std::unordered_map loop_group_map_; + // map> + std::unordered_map> frame_enter_map_; + // map> + std::unordered_map> loop_group_map_; }; } // namespace ge #endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index b1fb3bbd..9ab74d70 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -44,8 +44,6 @@ using std::set; using std::string; using std::vector; -using std::map; -using std::queue; namespace ge { namespace multibatch { @@ -59,15 +57,10 @@ const int kDataInIndex = 0; const int kMergeDataOutIndex = 0; const int kStaticOutput = -1; const int kDivisionConst = 2; -const int32_t kOneInDataNode = 1; -const int32_t kFindNoMatch = 0; inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } -inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); } -const set unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER}); - inline bool IsGetNextType(const NodePtr &node) { std::string original_type; GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, @@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() { return ret; } + ret = InsertIdentityAfterSwitchN(); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node."); + return INTERNAL_ERROR; + } + GELOGI("Begin to remove useless nodes by prune pass after copy process"); PrunePass prune_pass; ret = prune_pass.Run(graph_); @@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() { return ret; } - ret = RelinkConstCtrlEdge(); - if (ret != SUCCESS) { - GELOGE(FAILED, "Relink const's control edge failed."); - return FAILED; - } - - ret = ExtractUnchangedStructureOutofCycle(); - if (ret != SUCCESS) { - GELOGE(FAILED, "Extract unchanged structure out of cycle failed."); - return FAILED; - } - for (auto &node : graph_->GetAllNodes()) { origin_all_nodes_.emplace_back(node); if (IsDataLikeType(node->GetType())) { @@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() { return SUCCESS; } -Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() { - for (auto &node : graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node); - if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { - if (node->GetOutDataNodes().empty()) { - continue; - } - if (!node->GetInControlNodes().empty()) { - auto in_ctrl_nodes = node->GetInControlNodes(); - auto out_nodes = node->GetOutAllNodes(); - bool has_merge = false; - for (const auto &out_node : out_nodes) { - GE_CHECK_NOTNULL(out_node); - if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) { - has_merge = true; - break; - } - } - if (has_merge) { - continue; - } - auto in_ctrl_anchor = node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor); - in_ctrl_anchor->UnlinkAll(); - for (auto &in_ctrl_node : in_ctrl_nodes) { - auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node); - for (auto &out_node : out_nodes) { - if (IsEnterType(out_node->GetType())) { - continue; - } - if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) { - GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor())) - } - } - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - out_ctrl_anchor->UnlinkAll(); - } - } - } - - return SUCCESS; -} - -Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() { - map> frame_enter; - if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) { - GELOGE(FAILED, "Get enter nodes grouped by frame_name failed."); - return FAILED; - } - - queue nodes_to_extract; - if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) { - GELOGE(FAILED, "Get nodes needed to extract failed."); - return FAILED; - } - - while (!nodes_to_extract.empty()) { - auto node = nodes_to_extract.front(); - nodes_to_extract.pop(); - OpDescPtr enter_desc = nullptr; - if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) { - GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str()); - return FAILED; - } - set out_nodes; - if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) { - GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str()); - return FAILED; - } - - if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) { - GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str()); - return FAILED; - } - - for (auto &out_node : out_nodes) { - GE_CHECK_NOTNULL(out_node); - if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) { - nodes_to_extract.push(out_node); - } - } - } - - if (DeleteEnterWithoutDataOut() != SUCCESS) { - GELOGE(FAILED, "Delete enter node without out data nodes failed."); - return FAILED; - } - - return SUCCESS; -} - -Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map> &frame_enter) { - for (auto &node : graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node); - if (IsEnterType(node->GetType())) { - if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) { - continue; - } - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - string frame_name; - if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { - GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str()); - return FAILED; - } - frame_enter[frame_name].emplace_back(node); - } - } - - return SUCCESS; -} - -Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map> &frame_enter, - queue &nodes_to_extract) { - for (const auto &one_group : frame_enter) { - auto enters = one_group.second; - for (const auto &enter : enters) { - auto out_data_nodes = enter->GetOutDataNodes(); - for (const auto &out_data_node : out_data_nodes) { - GE_CHECK_NOTNULL(out_data_node); - if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) { - nodes_to_extract.push(out_data_node); - } - } - } - } - - return SUCCESS; -} - -bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) { - auto out_data_nodes = node->GetOutDataNodes(); - for (const auto &out_data_node : out_data_nodes) { - if (out_data_node == nullptr) { - return false; - } - - if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) { - return false; - } - } - - auto in_data_nodes = node->GetInDataNodes(); - if (in_data_nodes.size() == kOneInDataNode) { - return true; - } - - for (const auto &in_data_node : in_data_nodes) { - if (in_data_node == nullptr) { - return false; - } - if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) { - return false; - } - } - - return true; -} - -Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) { - auto in_data_anchors = node->GetAllInDataAnchors(); - for (auto &in_data_anchor : in_data_anchors) { - auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_data_anchor); - auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode(); - if (IsEnterType(peer_in_data_node->GetType())) { - GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor)) - GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str()); - auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors(); - for (auto &enter_in_data_anchor : enter_in_data_anchors) { - auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter); - if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) { - continue; - } - GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor)) - GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(), - node->GetName().c_str()); - } - enter_desc = peer_in_data_node->GetOpDesc(); - GE_CHECK_NOTNULL(enter_desc); - } - } - - return SUCCESS; -} - -Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr ©_desc, set &out_nodes) { - if (copy_desc == nullptr) { - return SUCCESS; - } - map>> outanchors_inanchors_nodes; - auto out_data_anchors = node->GetAllOutDataAnchors(); - for (auto &out_data_anchor : out_data_anchors) { - auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors(); - for (auto peer_in_data_anchor : peer_in_data_anchors) { - GE_CHECK_NOTNULL(peer_in_data_anchor); - auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); - out_nodes.emplace(peer_in_data_node); - outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node)); - } - } - - int32_t i = 0; - auto node_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(node_desc); - // Insert one enter node after node's per out data anchor - for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) { - string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++); - GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str()); - auto enter_desc = AttrUtils::CopyOpDesc(copy_desc); - enter_desc->SetName(name); - GE_CHK_STATUS_RET( - enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) - GE_CHK_STATUS_RET( - enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) - auto enter_node = graph_->AddNode(enter_desc); - GE_CHECK_NOTNULL(enter_node); - GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex))) - GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex)); - for (auto &inanchor_node : outanchor_inanchors_nodes.second) { - GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first)) - GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first)) - GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(), - inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(), - inanchor_node.second->GetName().c_str()); - } - } - - return SUCCESS; -} - -// Move node's in control edges to out data nodes -Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set &out_nodes) { - auto in_ctrl_anchor = node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor); - auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors(); - for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) { - GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor)) - GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - node->GetName().c_str()); - for (auto &out_node : out_nodes) { - auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node); - if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) { - GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node)) - GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - out_node->GetName().c_str()); - } - } - } - - return SUCCESS; -} - -Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() { - for (auto &node : graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node); - if (IsEnterType(node->GetType())) { - auto out_nodes = node->GetOutAllNodes(); - if (out_nodes.empty()) { - GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str()); - GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {})) - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node)) - } - } - } - - return SUCCESS; -} - void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), @@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { LabelStatusForGetNextSink(data); } } - - map> frame_enters; - InitStatus(frame_enters); bool changed = true; // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch while (changed) { @@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { if (iter != origin_nodes_status_.end()) { continue; } - for (auto &in_node : node->GetInDataNodes()) { - if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) { - if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { - origin_nodes_status_[node.get()] == kNodeInBatchBranch; - ResetEnterStatus(frame_enters, node); - changed = true; - } + for (auto &in_node : node->GetInAllNodes()) { + bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() && + origin_nodes_status_[in_node.get()] == kNodeInBatchBranch; + if (is_in_batch) { + origin_nodes_status_[node.get()] = kNodeInBatchBranch; + changed = true; break; } } @@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { return SUCCESS; } -void MultiBatchGraphCopyer::InitStatus(map> &frame_enters) { - for (const auto &node : origin_all_nodes_) { - if (!IsEnterType(node->GetType())) { - continue; - } - auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - string frame_name; - if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { - frame_enters[frame_name].emplace_back(node); - } - } - - for (const auto &data : origin_data_nodes_) { - auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); - if (!IsAllDimsPositive(data_shape.GetDims())) { - origin_nodes_status_[data.get()] = kNodeInBatchBranch; - } - } -} - -void MultiBatchGraphCopyer::ResetEnterStatus(map> &frame_enters, const NodePtr &node) { - if (!IsEnterType(node->GetType())) { - return; - } - - for (const auto &frame_enter : frame_enters) { - auto &enters = frame_enter.second; - if (std::find(enters.begin(), enters.end(), node) != enters.end()) { - for (const auto &enter : enters) { - origin_nodes_status_[enter.get()] = kNodeInBatchBranch; - } - break; - } - } -} - Status MultiBatchGraphCopyer::LabelStatus() { if (LabelInBatchBranchStatus() != SUCCESS) { GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); @@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { return SUCCESS; } +Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { + for (auto &node : graph_->GetAllNodes()) { + if (node->GetType() != SWITCHN) { + continue; + } + auto switchn_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_desc); + size_t i = 0; + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + auto out_node = in_data_anchor->GetOwnerNode(); + auto op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { + GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str()); + continue; + } + + auto identity_desc = MakeShared(node->GetName() + "_identity_" + std::to_string(i), IDENTITY); + GE_CHECK_NOTNULL(identity_desc); + + string batch_label; + if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { + if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { + GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str()); + return FAILED; + } + } + + auto data_desc = switchn_desc->GetOutputDesc(i); + i++; + GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc)); + GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc)); + + auto identity_node = graph_->AddNode(identity_desc); + GE_CHECK_NOTNULL(identity_node); + GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0))); + GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor()); + GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor())); + } + } + } + + return SUCCESS; +} + Status ProcessMultiBatch(ComputeGraphPtr &graph) { const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); if (multi_batch_with_case != nullptr) { diff --git a/ge/graph/preprocess/multi_batch_copy_graph.h b/ge/graph/preprocess/multi_batch_copy_graph.h index d51c4c02..a0de4413 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.h +++ b/ge/graph/preprocess/multi_batch_copy_graph.h @@ -18,7 +18,6 @@ #include #include #include -#include #include "external/ge/ge_api_error_codes.h" @@ -65,26 +64,12 @@ class MultiBatchGraphCopyer { private: Status Init(); Status CheckArguments(); - Status RelinkConstCtrlEdge(); - - Status ExtractUnchangedStructureOutofCycle(); - Status GetEnterNodesGroupByFrame(std::map> &frame_enter); - Status GetNodeNeedExtract(const std::map> &frame_enter, - std::queue &nodes_to_extract); - bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node); - Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc); - Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set &out_nodes); - Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set &out_nodes); - Status DeleteEnterWithoutDataOut(); // label status for origin_all_nodes_ Status LabelStatus(); Status LabelInBatchBranchStatus(); void LabelStatusForData(const NodePtr &data); void LabelStatusForGetNextSink(const NodePtr &data); - void InitStatus(std::map> &frame_enters); - void ResetEnterStatus(std::map> &frame_enters, const NodePtr &node); - // add nodes functions Status CreateNewNodes(); @@ -96,6 +81,7 @@ class MultiBatchGraphCopyer { Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, std::vector> &dynamic_out_to_switchn); + Status InsertIdentityAfterSwitchN(); Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);