Browse Source

!578 回退 'Pull Request !536 : Decrease transformer's om size in dynamic dims scenario'

From: @wqtshg
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
22a17b3173
21 changed files with 295 additions and 664 deletions
  1. +0
    -1
      ge/CMakeLists.txt
  2. +0
    -1
      ge/ge_inference.mk
  3. +0
    -1
      ge/ge_runner.mk
  4. +0
    -1
      ge/graph/build/model_builder.cc
  5. +0
    -5
      ge/graph/manager/graph_manager.cc
  6. +23
    -5
      ge/graph/passes/attach_stream_label_pass.cc
  7. +3
    -1
      ge/graph/passes/attach_stream_label_pass.h
  8. +1
    -1
      ge/graph/passes/base_pass.cc
  9. +1
    -2
      ge/graph/passes/common_subexpression_elimination_pass.cc
  10. +0
    -55
      ge/graph/passes/const_pass.cc
  11. +0
    -29
      ge/graph/passes/const_pass.h
  12. +0
    -64
      ge/graph/passes/dimension_adjust_pass.cc
  13. +0
    -4
      ge/graph/passes/dimension_adjust_pass.h
  14. +7
    -41
      ge/graph/passes/enter_pass.cc
  15. +1
    -2
      ge/graph/passes/enter_pass.h
  16. +4
    -1
      ge/graph/passes/folding_pass.cc
  17. +10
    -0
      ge/graph/passes/merge_to_stream_merge_pass.cc
  18. +173
    -89
      ge/graph/passes/next_iteration_pass.cc
  19. +13
    -3
      ge/graph/passes/next_iteration_pass.h
  20. +58
    -343
      ge/graph/preprocess/multi_batch_copy_graph.cc
  21. +1
    -15
      ge/graph/preprocess/multi_batch_copy_graph.h

+ 0
- 1
ge/CMakeLists.txt View File

@@ -154,7 +154,6 @@ set(TRAIN_SRC_LIST
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/control_trigger_pass.cc"
"graph/passes/dimension_adjust_pass.cc"
"graph/passes/dimension_compute_pass.cc"


+ 0
- 1
ge/ge_inference.mk View File

@@ -189,7 +189,6 @@ OMG_HOST_SRC_FILES := \
graph/passes/control_trigger_pass.cc \
graph/passes/cond_pass.cc \
graph/passes/cond_remove_pass.cc \
graph/passes/const_pass.cc \
graph/passes/for_pass.cc \
graph/passes/enter_pass.cc \
graph/passes/assign_pass.cc \


+ 0
- 1
ge/ge_runner.mk View File

@@ -123,7 +123,6 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/compile_nodes_pass.cc \
graph/passes/constant_folding_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/const_pass.cc \
graph/passes/control_trigger_pass.cc \
graph/passes/dimension_adjust_pass.cc \
graph/passes/dimension_compute_pass.cc \


+ 0
- 1
ge/graph/build/model_builder.cc View File

@@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_
GeTensorDesc &tensor_desc = weight->MutableTensorDesc();
size_t output_size = weight->GetData().size();
TensorUtils::SetDataOffset(tensor_desc, mem_offset);
GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size);
mem_offset += output_size;
}
return SUCCESS;


+ 0
- 5
ge/graph/manager/graph_manager.cc View File

@@ -56,7 +56,6 @@
#include "graph/passes/cond_remove_pass.h"
#include "graph/passes/constant_folding_pass.h"
#include "graph/passes/constant_fuse_same_pass.h"
#include "graph/passes/const_pass.cc"
#include "graph/passes/control_trigger_pass.h"
#include "graph/passes/ctrl_edge_transfer_pass.h"
#include "graph/passes/dimension_adjust_pass.h"
@@ -2138,7 +2137,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
TransposeTransDataPass transpose_transdata_pass;
TransOpSymmetryEliminationPass symmetry_elimination_pass;
DimensionComputePass dimension_compute_pass;
ConstPass const_pass;
names_to_passes.emplace_back("EnterPass", &enter_pass);
names_to_passes.emplace_back("AddNPass", &addn_pass);
names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination);
@@ -2152,7 +2150,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass);
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass);
names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass);
names_to_passes.emplace_back("ConstPass", &const_pass);
GE_TIMESTAMP_START(names_to_passes);
ret = GEPass(compute_graph).Run(names_to_passes);
GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2");
@@ -2193,8 +2190,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass",
new (std::nothrow) VariableRefUselessControlOutDeletePass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::CommonSubexpressionEliminationPass",
new (std::nothrow) CommonSubexpressionEliminationPass));
if (options_.train_graph_flag) {
// Priority: The GlobalStepInsertPass should work before graph partitioner.
// Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory


+ 23
- 5
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -18,8 +18,6 @@
#include "ge/ge_api_types.h"
#include "graph/common/omg_util.h"

using std::string;

namespace ge {
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {
GELOGD("AttachStreamLabelPass Enter.");
@@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
}

std::stack<NodePtr> enter_nodes;
std::string batch_label;
for (const auto &enter_node : pair.second) {
enter_nodes.emplace(enter_node);
std::string tmp_label;
(void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty()) {
if (batch_label.empty()) {
batch_label = tmp_label;
} else if (batch_label != tmp_label) {
GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str());
return FAILED;
}
}
}
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) {
if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) {
GELOGE(FAILED, "Update stream_label for loop_branch failed.");
return FAILED;
}
@@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
}

for (const auto &enter_node : enter_nodes) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
GE_CHECK_NOTNULL(enter_node->GetOpDesc());
if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
}
}
return SUCCESS;
}
@@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
/// @param [in] batch_label
/// @return Status
///
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) {
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label) {
std::stack<NodePtr> nodes(enter_nodes);
NodePtr cur_node = nullptr;
while (!nodes.empty()) {
@@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_
for (const NodePtr &out_node : cur_node->GetOutAllNodes()) {
OpDescPtr out_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(out_desc);
std::string tmp_label;
(void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty() && (tmp_label != batch_label)) {
continue;
}
std::string out_type = out_desc->GetType();
bool need_skip =
out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) ||


+ 3
- 1
ge/graph/passes/attach_stream_label_pass.h View File

@@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass {
/// @brief Update stream_label for loop_branch
/// @param [in] enter_nodes
/// @param [in] stream_label
/// @param [in] batch_label
/// @return Status
///
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label);
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label);

///
/// @brief Update stream_label start with enter nodes


+ 1
- 1
ge/graph/passes/base_pass.cc View File

@@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder
node->GetName().c_str(), node->GetType().c_str());
continue;
}
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen) || node_to_re_pass->GetType() == ENTER) {
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str());
nodes_re_pass.insert(node_to_re_pass);
} else {


+ 1
- 2
ge/graph/passes/common_subexpression_elimination_pass.cc View File

@@ -58,8 +58,7 @@ std::string GetCseKey(const NodePtr &node) {
/// To avoid delete wrong nodes(e.g. stateful nodes),
/// only nodes have folding kernel will be considered for the CSE process
bool IsNodeSupportCse(const NodePtr &node) {
if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node)) || node->GetType() == CONSTANT ||
node->GetType() == CONSTANTOP) {
if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node))) {
return true;
}
return folding_pass::GetKernelByType(node) != nullptr;


+ 0
- 55
ge/graph/passes/const_pass.cc View File

@@ -1,55 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/const_pass.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"

namespace ge {
Status ConstPass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);

if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) {
return SUCCESS;
}
GELOGD("ConstPass running, node: %s.", node->GetName().c_str());

// const has no control input
if (node->GetInControlNodes().empty()) {
auto out_ctrl_anchor = node->GetOutControlAnchor();
if (out_ctrl_anchor != nullptr) {
GELOGD("Node: %s unlink all out control edge.", node->GetName().c_str());
out_ctrl_anchor->UnlinkAll();
}

if (node->GetOutAllNodes().empty()) {
// it is an isolated const, just remove it.
GELOGD("Delete isolated const: %s.", node->GetName().c_str());
auto graph = node->GetOwnerComputeGraph();
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove const %s failed.", node->GetName().c_str());
return FAILED;
}
AddNodeDeleted(node);
}
}

return SUCCESS;
}
} // namespace ge

+ 0
- 29
ge/graph/passes/const_pass.h View File

@@ -1,29 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_CONST_PASS_H_
#define GE_GRAPH_PASSES_CONST_PASS_H_

#include "graph/passes/base_pass.h"

namespace ge {
class ConstPass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_CONST_PASS_H_

+ 0
- 64
ge/graph/passes/dimension_adjust_pass.cc View File

@@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) {
}
}

ret = DealWithInNodes(node);
if (ret != SUCCESS) {
GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str());
return ret;
}

std::vector<int> data_relink_io_map = {kDataInputIndex};
return IsolateAndDeleteNode(node, data_relink_io_map);
}

Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
auto graph = node->GetOwnerComputeGraph();
auto in_data_anchors = node->GetAllInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
if (in_data_anchor == nullptr) {
continue;
}
auto in_node_anchor = in_data_anchor->GetPeerOutAnchor();
if (in_node_anchor == nullptr) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if (in_node->GetType() == SWITCHN) {
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx());
auto identity =
AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph);
GE_CHECK_NOTNULL(identity);
GELOGI("Create new identity node[%s] success.", identity->GetName().c_str());
GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0)))
GE_CHECK_NOTNULL(identity->GetOutControlAnchor());
if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) {
continue;
}
GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor()))
}
}

return SUCCESS;
}

NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor,
ComputeGraphPtr &graph) {
if (graph == nullptr) {
GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node.");
return nullptr;
}

OpDescPtr desc = MakeShared<OpDesc>("", "");
if (desc == nullptr) {
GELOGE(MEMALLOC_FAILED, "Failed to create op desc.");
return nullptr;
}

desc->SetName(name);
desc->SetType(IDENTITY);
auto ret = desc->AddInputDesc(tensor);
auto ret2 = desc->AddOutputDesc(tensor);
if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity.");
return nullptr;
}

return graph->AddNodeFront(desc);
}
} // namespace ge

+ 0
- 4
ge/graph/passes/dimension_adjust_pass.h View File

@@ -34,10 +34,6 @@ namespace ge {
class DimensionAdjustPass : public BaseNodePass {
public:
Status Run(ge::NodePtr &node) override;

private:
Status DealWithInNodes(ge::NodePtr &node);
NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph);
};
} // namespace ge



+ 7
- 41
ge/graph/passes/enter_pass.cc View File

@@ -23,7 +23,6 @@

namespace {
const size_t kOutNodesNum = 1;
const size_t kInCtrlNodesNum = 1;
}

namespace ge {
@@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) {
if (out_ctrl_node == nullptr) {
continue;
}
GELOGD("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str());
if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(),
out_ctrl_node->GetName().c_str());
@@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) {
}
}
} else {
if (OptimizeEnterWithOnlyOutData(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str());
return FAILED;
}
if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) {
GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str());
if (OptimizeEnter(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str());
return FAILED;
}
}
@@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) {
return SUCCESS;
}

Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) {
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) {
return SUCCESS;
}
@@ -89,45 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node)
}

GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)));
const auto &out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(out_data_anchor);
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor))
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor));
}
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node));
AddNodeDeleted(node);
AddRePassNodesWithInOut(in_node);

return SUCCESS;
}

Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) {
auto out_ctrl_nodes = node->GetOutControlNodes();
if (out_ctrl_nodes.empty()) {
return SUCCESS;
}
auto out_ctrl_anchor = node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor);

for (auto &out_ctrl_node : out_ctrl_nodes) {
GE_CHECK_NOTNULL(out_ctrl_node);
if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) {
continue;
}
auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes();
if (in_ctrl_nodes.size() != kInCtrlNodesNum) {
continue;
}
GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor()))
auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes();
for (auto &out_node_of_const : out_nodes_of_const) {
if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) {
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor()))
}
}
}
return SUCCESS;
}
} // namespace ge

+ 1
- 2
ge/graph/passes/enter_pass.h View File

@@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass {
Status Run(NodePtr &node) override;

private:
Status OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node);
Status UnlinkCtrlEdgeBeforeConst(NodePtr &node);
Status OptimizeEnter(NodePtr &node, NodePtr &in_node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_ENTER_PASS_H_

+ 4
- 1
ge/graph/passes/folding_pass.cc View File

@@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) {
if (in_node == nullptr) {
continue;
}
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) {
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
auto ret = in_node_anchor->Unlink(in_data_anchor);
if (ret != SUCCESS) {


+ 10
- 0
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co
GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed");
}

if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
string batch_label;
(void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (!batch_label.empty()) {
auto stream_merge_desc = stream_merge->GetOpDesc();
GE_CHECK_NOTNULL(stream_merge_desc);
(void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label);
}
}

return AddActiveNodes(graph, stream_merge);
}



+ 173
- 89
ge/graph/passes/next_iteration_pass.cc View File

@@ -19,8 +19,6 @@
#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"

using std::string;

namespace ge {
Status NextIterationPass::Run(ComputeGraphPtr graph) {
GELOGD("NextIterationPass Enter");
@@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) {
return INTERNAL_ERROR;
}
}
if (GroupWithNoBatch(graph) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr.");
return INTERNAL_ERROR;
}

if (FindWhileGroups() != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Find while groups failed.");
@@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
return FAILED;
}

string batch_label;
if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
frame_name += batch_label;
std::string batch_label;
(void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty()) {
auto frame_iter = frame_enter_map_.find(frame_name);
if (frame_iter == frame_enter_map_.end()) {
std::vector<NodePtr> enter_nodes;
enter_nodes.emplace_back(enter_node);
frame_enter_map_[frame_name] = enter_nodes;
} else {
frame_iter->second.emplace_back(enter_node);
}
return SUCCESS;
}

auto iter = loop_group_map_.find(frame_name);
if (iter == loop_group_map_.end()) {
auto group_iter = loop_group_map_.find(frame_name);
if (group_iter == loop_group_map_.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes.emplace_back(enter_node);
loop_group_map_[frame_name] = loop_group;
loop_group_map_[frame_name][batch_label] = loop_group;
} else {
iter->second->enter_nodes.emplace_back(enter_node);
auto batch_iter = group_iter->second.find(batch_label);
if (batch_iter == group_iter->second.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes.emplace_back(enter_node);
group_iter->second[batch_label] = loop_group;
} else {
batch_iter->second->enter_nodes.emplace_back(enter_node);
}
}

return SUCCESS;
}

///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) {
if (frame_enter_map_.empty()) {
GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str());
return SUCCESS;
}
for (const auto &item : frame_enter_map_) {
const std::string &frame_name = item.first;
auto iter = loop_group_map_.find(frame_name);
if (iter == loop_group_map_.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes = item.second;
loop_group_map_[frame_name][""] = loop_group;
} else {
for (auto &batch_item : iter->second) {
for (const auto &enter_node : item.second) {
batch_item.second->enter_nodes.emplace_back(enter_node);
}
}
}
}

return SUCCESS;
@@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
Status NextIterationPass::FindWhileGroups() {
for (const auto &loop_group_iter : loop_group_map_) {
const std::string &frame_name = loop_group_iter.first;
for (const auto &enter_node : loop_group_iter.second->enter_nodes) {
for (const auto &out_node : enter_node->GetOutAllNodes()) {
const string &type = out_node->GetType();
if ((type != MERGE) && (type != REFMERGE)) {
continue;
}

NodePtr next_node = nullptr;
if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str());
return INTERNAL_ERROR;
}
loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

NodePtr switch_node = nullptr;
if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
if (switch_node == nullptr) {
continue;
}

NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
if (loop_group_iter.second->loop_cond == nullptr) {
loop_group_iter.second->loop_cond = loop_cond;
} else if (loop_group_iter.second->loop_cond != loop_cond) {
GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str());
return FAILED;
for (const auto &batch_iter : loop_group_iter.second) {
const std::string &batch_label = batch_iter.first;
for (const auto &enter_node : batch_iter.second->enter_nodes) {
for (const auto &out_node : enter_node->GetOutAllNodes()) {
GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(),
frame_name.c_str(), batch_label.c_str());
if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) {
continue;
}
std::string tmp_label;
GE_CHECK_NOTNULL(out_node->GetOpDesc());
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
if (need_skip) {
continue;
}

NodePtr next_node = nullptr;
if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR,
"Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s",
out_node->GetName().c_str());
return INTERNAL_ERROR;
}
batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

NodePtr switch_node = nullptr;
if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s",
out_node->GetName().c_str());
return INTERNAL_ERROR;
}
if (switch_node == nullptr) {
continue;
}

NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR,
"Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s",
switch_node->GetName().c_str());
return INTERNAL_ERROR;
}
if (batch_iter.second->loop_cond == nullptr) {
batch_iter.second->loop_cond = loop_cond;
} else if (batch_iter.second->loop_cond != loop_cond) {
GELOGE(FAILED, "Multi LoopCond nodes exist.");
return FAILED;
}
}
}
}
@@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() {
GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
return false;
}
if (loop_group_iter.second->loop_cond == nullptr) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
return false;
}

for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) {
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
frame_name.c_str());
for (const auto &batch_iter : loop_group_iter.second) {
if (batch_iter.second->loop_cond == nullptr) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
return false;
}

for (const auto &pair_iter : batch_iter.second->merge_next_pairs) {
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
frame_name.c_str());
return false;
}
}
}
}

@@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() {
///
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
for (const auto &loop_cond_iter : loop_group_map_) {
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
if ((enter_active == nullptr) || (next_active == nullptr)) {
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
return INTERNAL_ERROR;
}

for (const auto &enter_node : loop_cond_iter.second->enter_nodes) {
// Enter --> Active
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(),
enter_active->GetName().c_str());
for (const auto &batch_iter : loop_cond_iter.second) {
const std::string &cond_name = batch_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
if ((enter_active == nullptr) || (next_active == nullptr)) {
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
return INTERNAL_ERROR;
}
}

for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
NodePtr merge_node = pair.first;
NodePtr next_node = pair.second;
// Active --> Merge
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
for (const auto &enter_node : batch_iter.second->enter_nodes) {
// Enter --> Active
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) !=
GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}
}

// NextIteration --> Active
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
for (const auto &pair : batch_iter.second->merge_next_pairs) {
NodePtr merge_node = pair.first;
NodePtr next_node = pair.second;
// Active --> Merge
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) !=
GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}

// NextIteration --> Active
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}

// break link between NextIteration and Merge
if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
return INTERNAL_ERROR;
}
}

// break link between NextIteration and Merge
if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR;
}
}

if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR;
}
}

return SUCCESS;
@@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &
/// @param [in] node
/// @param [in] target_type
/// @param [in] is_input
/// @param [in] batch_label
/// @param [out] target_node
/// @return Status
///
Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
NodePtr &target_node) {
const std::string &batch_label, NodePtr &target_node) {
if (node == nullptr) {
GELOGE(PARAM_INVALID, "node is null.");
return PARAM_INVALID;
@@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string
}

for (const auto &tmp_node : nodes) {
std::string tmp_label;
(void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
if (need_skip) {
continue;
}
const std::string type = tmp_node->GetType();
if ((target_type == LOOPCOND) && (type == target_type)) {
target_node = tmp_node;
@@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string
/// @return SUCCESS
///
Status NextIterationPass::ClearStatus() {
frame_enter_map_.clear();
loop_group_map_.clear();
return SUCCESS;
}


+ 13
- 3
ge/graph/passes/next_iteration_pass.h View File

@@ -46,6 +46,13 @@ class NextIterationPass : public GraphPass {
///
Status GroupEnterNode(const NodePtr &enter_node);

///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status GroupWithNoBatch(const ComputeGraphPtr &graph);

///
/// @brief Find while groups
/// @return Status
@@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass {
/// @param [out] target_node
/// @return Status
///
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node);
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
const std::string &batch_label, NodePtr &target_node);

// map<frame_name, LoopCondGroup>
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_;
// map<frame_name, vector<enter_node>>
std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_;
// map<frame_name, map<batch_label, LoopCondGroup>>
std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_

+ 58
- 343
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -44,8 +44,6 @@
using std::set;
using std::string;
using std::vector;
using std::map;
using std::queue;

namespace ge {
namespace multibatch {
@@ -59,15 +57,10 @@ const int kDataInIndex = 0;
const int kMergeDataOutIndex = 0;
const int kStaticOutput = -1;
const int kDivisionConst = 2;
const int32_t kOneInDataNode = 1;
const int32_t kFindNoMatch = 0;


inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }

inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); }
const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER});

inline bool IsGetNextType(const NodePtr &node) {
std::string original_type;
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
@@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() {
return ret;
}

ret = InsertIdentityAfterSwitchN();
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node.");
return INTERNAL_ERROR;
}

GELOGI("Begin to remove useless nodes by prune pass after copy process");
PrunePass prune_pass;
ret = prune_pass.Run(graph_);
@@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() {
return ret;
}

ret = RelinkConstCtrlEdge();
if (ret != SUCCESS) {
GELOGE(FAILED, "Relink const's control edge failed.");
return FAILED;
}

ret = ExtractUnchangedStructureOutofCycle();
if (ret != SUCCESS) {
GELOGE(FAILED, "Extract unchanged structure out of cycle failed.");
return FAILED;
}

for (auto &node : graph_->GetAllNodes()) {
origin_all_nodes_.emplace_back(node);
if (IsDataLikeType(node->GetType())) {
@@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() {
return SUCCESS;
}

Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() {
for (auto &node : graph_->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
if (node->GetOutDataNodes().empty()) {
continue;
}
if (!node->GetInControlNodes().empty()) {
auto in_ctrl_nodes = node->GetInControlNodes();
auto out_nodes = node->GetOutAllNodes();
bool has_merge = false;
for (const auto &out_node : out_nodes) {
GE_CHECK_NOTNULL(out_node);
if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) {
has_merge = true;
break;
}
}
if (has_merge) {
continue;
}
auto in_ctrl_anchor = node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor);
in_ctrl_anchor->UnlinkAll();
for (auto &in_ctrl_node : in_ctrl_nodes) {
auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node);
for (auto &out_node : out_nodes) {
if (IsEnterType(out_node->GetType())) {
continue;
}
if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) {
GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor()))
}
}
}
}
auto out_ctrl_anchor = node->GetOutControlAnchor();
if (out_ctrl_anchor != nullptr) {
out_ctrl_anchor->UnlinkAll();
}
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() {
map<string, vector<NodePtr>> frame_enter;
if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) {
GELOGE(FAILED, "Get enter nodes grouped by frame_name failed.");
return FAILED;
}

queue<NodePtr> nodes_to_extract;
if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) {
GELOGE(FAILED, "Get nodes needed to extract failed.");
return FAILED;
}

while (!nodes_to_extract.empty()) {
auto node = nodes_to_extract.front();
nodes_to_extract.pop();
OpDescPtr enter_desc = nullptr;
if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) {
GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str());
return FAILED;
}
set<NodePtr> out_nodes;
if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) {
GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str());
return FAILED;
}

if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) {
GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str());
return FAILED;
}

for (auto &out_node : out_nodes) {
GE_CHECK_NOTNULL(out_node);
if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) {
nodes_to_extract.push(out_node);
}
}
}

if (DeleteEnterWithoutDataOut() != SUCCESS) {
GELOGE(FAILED, "Delete enter node without out data nodes failed.");
return FAILED;
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) {
for (auto &node : graph_->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
if (IsEnterType(node->GetType())) {
if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) {
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
string frame_name;
if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str());
return FAILED;
}
frame_enter[frame_name].emplace_back(node);
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter,
queue<NodePtr> &nodes_to_extract) {
for (const auto &one_group : frame_enter) {
auto enters = one_group.second;
for (const auto &enter : enters) {
auto out_data_nodes = enter->GetOutDataNodes();
for (const auto &out_data_node : out_data_nodes) {
GE_CHECK_NOTNULL(out_data_node);
if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) {
nodes_to_extract.push(out_data_node);
}
}
}
}

return SUCCESS;
}

bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) {
auto out_data_nodes = node->GetOutDataNodes();
for (const auto &out_data_node : out_data_nodes) {
if (out_data_node == nullptr) {
return false;
}

if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) {
return false;
}
}

auto in_data_nodes = node->GetInDataNodes();
if (in_data_nodes.size() == kOneInDataNode) {
return true;
}

for (const auto &in_data_node : in_data_nodes) {
if (in_data_node == nullptr) {
return false;
}
if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) {
return false;
}
}

return true;
}

Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) {
auto in_data_anchors = node->GetAllInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_data_anchor);
auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode();
if (IsEnterType(peer_in_data_node->GetType())) {
GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor))
GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str());
auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors();
for (auto &enter_in_data_anchor : enter_in_data_anchors) {
auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter);
if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) {
continue;
}
GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor))
GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(),
node->GetName().c_str());
}
enter_desc = peer_in_data_node->GetOpDesc();
GE_CHECK_NOTNULL(enter_desc);
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr &copy_desc, set<NodePtr> &out_nodes) {
if (copy_desc == nullptr) {
return SUCCESS;
}
map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes;
auto out_data_anchors = node->GetAllOutDataAnchors();
for (auto &out_data_anchor : out_data_anchors) {
auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors();
for (auto peer_in_data_anchor : peer_in_data_anchors) {
GE_CHECK_NOTNULL(peer_in_data_anchor);
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
out_nodes.emplace(peer_in_data_node);
outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node));
}
}

int32_t i = 0;
auto node_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(node_desc);
// Insert one enter node after node's per out data anchor
for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) {
string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++);
GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str());
auto enter_desc = AttrUtils::CopyOpDesc(copy_desc);
enter_desc->SetName(name);
GE_CHK_STATUS_RET(
enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
GE_CHK_STATUS_RET(
enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
auto enter_node = graph_->AddNode(enter_desc);
GE_CHECK_NOTNULL(enter_node);
GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex)))
GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex));
for (auto &inanchor_node : outanchor_inanchors_nodes.second) {
GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first))
GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first))
GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(),
inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(),
inanchor_node.second->GetName().c_str());
}
}

return SUCCESS;
}

// Move node's in control edges to out data nodes
Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) {
auto in_ctrl_anchor = node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor);
auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors();
for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) {
GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor))
GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
node->GetName().c_str());
for (auto &out_node : out_nodes) {
auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node);
if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) {
GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node))
GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
out_node->GetName().c_str());
}
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() {
for (auto &node : graph_->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
if (IsEnterType(node->GetType())) {
auto out_nodes = node->GetOutAllNodes();
if (out_nodes.empty()) {
GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str());
GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {}))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node))
}
}
}

return SUCCESS;
}

void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) {
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(),
@@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
LabelStatusForGetNextSink(data);
}
}

map<string, vector<NodePtr>> frame_enters;
InitStatus(frame_enters);
bool changed = true;
// If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch
while (changed) {
@@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
if (iter != origin_nodes_status_.end()) {
continue;
}
for (auto &in_node : node->GetInDataNodes()) {
if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) {
if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
origin_nodes_status_[node.get()] == kNodeInBatchBranch;
ResetEnterStatus(frame_enters, node);
changed = true;
}
for (auto &in_node : node->GetInAllNodes()) {
bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() &&
origin_nodes_status_[in_node.get()] == kNodeInBatchBranch;
if (is_in_batch) {
origin_nodes_status_[node.get()] = kNodeInBatchBranch;
changed = true;
break;
}
}
@@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
return SUCCESS;
}

void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) {
for (const auto &node : origin_all_nodes_) {
if (!IsEnterType(node->GetType())) {
continue;
}
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
string frame_name;
if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
frame_enters[frame_name].emplace_back(node);
}
}

for (const auto &data : origin_data_nodes_) {
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
if (!IsAllDimsPositive(data_shape.GetDims())) {
origin_nodes_status_[data.get()] = kNodeInBatchBranch;
}
}
}

void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) {
if (!IsEnterType(node->GetType())) {
return;
}

for (const auto &frame_enter : frame_enters) {
auto &enters = frame_enter.second;
if (std::find(enters.begin(), enters.end(), node) != enters.end()) {
for (const auto &enter : enters) {
origin_nodes_status_[enter.get()] = kNodeInBatchBranch;
}
break;
}
}
}

Status MultiBatchGraphCopyer::LabelStatus() {
if (LabelInBatchBranchStatus() != SUCCESS) {
GELOGE(PARAM_INVALID, "Failed to label no in batch branch");
@@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
return SUCCESS;
}

Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
for (auto &node : graph_->GetAllNodes()) {
if (node->GetType() != SWITCHN) {
continue;
}
auto switchn_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(switchn_desc);
size_t i = 0;
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto out_node = in_data_anchor->GetOwnerNode();
auto op_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str());
continue;
}

auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY);
GE_CHECK_NOTNULL(identity_desc);

string batch_label;
if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str());
return FAILED;
}
}

auto data_desc = switchn_desc->GetOutputDesc(i);
i++;
GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc));
GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc));

auto identity_node = graph_->AddNode(identity_desc);
GE_CHECK_NOTNULL(identity_node);
GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0)));
GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor());
GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor()));
}
}
}

return SUCCESS;
}

Status ProcessMultiBatch(ComputeGraphPtr &graph) {
const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE");
if (multi_batch_with_case != nullptr) {


+ 1
- 15
ge/graph/preprocess/multi_batch_copy_graph.h View File

@@ -18,7 +18,6 @@
#include <map>
#include <queue>
#include <vector>
#include <set>

#include "external/ge/ge_api_error_codes.h"

@@ -65,26 +64,12 @@ class MultiBatchGraphCopyer {
private:
Status Init();
Status CheckArguments();
Status RelinkConstCtrlEdge();

Status ExtractUnchangedStructureOutofCycle();
Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter);
Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter,
std::queue<NodePtr> &nodes_to_extract);
bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node);
Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc);
Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes);
Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes);
Status DeleteEnterWithoutDataOut();

// label status for origin_all_nodes_
Status LabelStatus();
Status LabelInBatchBranchStatus();
void LabelStatusForData(const NodePtr &data);
void LabelStatusForGetNextSink(const NodePtr &data);
void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters);
void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node);

// add nodes functions
Status CreateNewNodes();

@@ -96,6 +81,7 @@ class MultiBatchGraphCopyer {
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);

Status InsertIdentityAfterSwitchN();
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);



Loading…
Cancel
Save