@@ -154,7 +154,6 @@ set(TRAIN_SRC_LIST | |||
"graph/passes/compile_nodes_pass.cc" | |||
"graph/passes/constant_folding_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
"graph/passes/control_trigger_pass.cc" | |||
"graph/passes/dimension_adjust_pass.cc" | |||
"graph/passes/dimension_compute_pass.cc" | |||
@@ -189,7 +189,6 @@ OMG_HOST_SRC_FILES := \ | |||
graph/passes/control_trigger_pass.cc \ | |||
graph/passes/cond_pass.cc \ | |||
graph/passes/cond_remove_pass.cc \ | |||
graph/passes/const_pass.cc \ | |||
graph/passes/for_pass.cc \ | |||
graph/passes/enter_pass.cc \ | |||
graph/passes/assign_pass.cc \ | |||
@@ -123,7 +123,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
graph/passes/compile_nodes_pass.cc \ | |||
graph/passes/constant_folding_pass.cc \ | |||
graph/passes/constant_fuse_same_pass.cc \ | |||
graph/passes/const_pass.cc \ | |||
graph/passes/control_trigger_pass.cc \ | |||
graph/passes/dimension_adjust_pass.cc \ | |||
graph/passes/dimension_compute_pass.cc \ | |||
@@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_ | |||
GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); | |||
size_t output_size = weight->GetData().size(); | |||
TensorUtils::SetDataOffset(tensor_desc, mem_offset); | |||
GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size); | |||
mem_offset += output_size; | |||
} | |||
return SUCCESS; | |||
@@ -56,7 +56,6 @@ | |||
#include "graph/passes/cond_remove_pass.h" | |||
#include "graph/passes/constant_folding_pass.h" | |||
#include "graph/passes/constant_fuse_same_pass.h" | |||
#include "graph/passes/const_pass.cc" | |||
#include "graph/passes/control_trigger_pass.h" | |||
#include "graph/passes/ctrl_edge_transfer_pass.h" | |||
#include "graph/passes/dimension_adjust_pass.h" | |||
@@ -2138,7 +2137,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
TransposeTransDataPass transpose_transdata_pass; | |||
TransOpSymmetryEliminationPass symmetry_elimination_pass; | |||
DimensionComputePass dimension_compute_pass; | |||
ConstPass const_pass; | |||
names_to_passes.emplace_back("EnterPass", &enter_pass); | |||
names_to_passes.emplace_back("AddNPass", &addn_pass); | |||
names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | |||
@@ -2152,7 +2150,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | |||
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | |||
names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | |||
names_to_passes.emplace_back("ConstPass", &const_pass); | |||
GE_TIMESTAMP_START(names_to_passes); | |||
ret = GEPass(compute_graph).Run(names_to_passes); | |||
GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2"); | |||
@@ -2193,8 +2190,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", | |||
new (std::nothrow) VariableRefUselessControlOutDeletePass)) | |||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass)) | |||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::CommonSubexpressionEliminationPass", | |||
new (std::nothrow) CommonSubexpressionEliminationPass)); | |||
if (options_.train_graph_flag) { | |||
// Priority: The GlobalStepInsertPass should work before graph partitioner. | |||
// Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory | |||
@@ -18,8 +18,6 @@ | |||
#include "ge/ge_api_types.h" | |||
#include "graph/common/omg_util.h" | |||
using std::string; | |||
namespace ge { | |||
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||
GELOGD("AttachStreamLabelPass Enter."); | |||
@@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||
} | |||
std::stack<NodePtr> enter_nodes; | |||
std::string batch_label; | |||
for (const auto &enter_node : pair.second) { | |||
enter_nodes.emplace(enter_node); | |||
std::string tmp_label; | |||
(void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||
if (!tmp_label.empty()) { | |||
if (batch_label.empty()) { | |||
batch_label = tmp_label; | |||
} else if (batch_label != tmp_label) { | |||
GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str()); | |||
return FAILED; | |||
} | |||
} | |||
} | |||
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { | |||
if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) { | |||
GELOGE(FAILED, "Update stream_label for loop_branch failed."); | |||
return FAILED; | |||
} | |||
@@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||
} | |||
for (const auto &enter_node : enter_nodes) { | |||
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||
GE_CHECK_NOTNULL(enter_node->GetOpDesc()); | |||
if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||
/// @param [in] batch_label | |||
/// @return Status | |||
/// | |||
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) { | |||
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label, | |||
const std::string &batch_label) { | |||
std::stack<NodePtr> nodes(enter_nodes); | |||
NodePtr cur_node = nullptr; | |||
while (!nodes.empty()) { | |||
@@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_ | |||
for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { | |||
OpDescPtr out_desc = out_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(out_desc); | |||
std::string tmp_label; | |||
(void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label); | |||
if (!tmp_label.empty() && (tmp_label != batch_label)) { | |||
continue; | |||
} | |||
std::string out_type = out_desc->GetType(); | |||
bool need_skip = | |||
out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) || | |||
@@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass { | |||
/// @brief Update stream_label for loop_branch | |||
/// @param [in] enter_nodes | |||
/// @param [in] stream_label | |||
/// @param [in] batch_label | |||
/// @return Status | |||
/// | |||
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label); | |||
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label, | |||
const std::string &batch_label); | |||
/// | |||
/// @brief Update stream_label start with enter nodes | |||
@@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder | |||
node->GetName().c_str(), node->GetType().c_str()); | |||
continue; | |||
} | |||
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen) || node_to_re_pass->GetType() == ENTER) { | |||
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { | |||
GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); | |||
nodes_re_pass.insert(node_to_re_pass); | |||
} else { | |||
@@ -58,8 +58,7 @@ std::string GetCseKey(const NodePtr &node) { | |||
/// To avoid delete wrong nodes(e.g. stateful nodes), | |||
/// only nodes have folding kernel will be considered for the CSE process | |||
bool IsNodeSupportCse(const NodePtr &node) { | |||
if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node)) || node->GetType() == CONSTANT || | |||
node->GetType() == CONSTANTOP) { | |||
if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node))) { | |||
return true; | |||
} | |||
return folding_pass::GetKernelByType(node) != nullptr; | |||
@@ -1,55 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/const_pass.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "framework/common/debug/log.h" | |||
namespace ge { | |||
Status ConstPass::Run(NodePtr &node) { | |||
GE_CHECK_NOTNULL(node); | |||
if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) { | |||
return SUCCESS; | |||
} | |||
GELOGD("ConstPass running, node: %s.", node->GetName().c_str()); | |||
// const has no control input | |||
if (node->GetInControlNodes().empty()) { | |||
auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||
if (out_ctrl_anchor != nullptr) { | |||
GELOGD("Node: %s unlink all out control edge.", node->GetName().c_str()); | |||
out_ctrl_anchor->UnlinkAll(); | |||
} | |||
if (node->GetOutAllNodes().empty()) { | |||
// it is an isolated const, just remove it. | |||
GELOGD("Delete isolated const: %s.", node->GetName().c_str()); | |||
auto graph = node->GetOwnerComputeGraph(); | |||
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "Remove const %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
AddNodeDeleted(node); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -1,29 +0,0 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_CONST_PASS_H_ | |||
#define GE_GRAPH_PASSES_CONST_PASS_H_ | |||
#include "graph/passes/base_pass.h" | |||
namespace ge { | |||
class ConstPass : public BaseNodePass { | |||
public: | |||
Status Run(NodePtr &node) override; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_CONST_PASS_H_ |
@@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||
} | |||
} | |||
ret = DealWithInNodes(node); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str()); | |||
return ret; | |||
} | |||
std::vector<int> data_relink_io_map = {kDataInputIndex}; | |||
return IsolateAndDeleteNode(node, data_relink_io_map); | |||
} | |||
Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) { | |||
GE_CHECK_NOTNULL(node); | |||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
auto graph = node->GetOwnerComputeGraph(); | |||
auto in_data_anchors = node->GetAllInDataAnchors(); | |||
for (auto &in_data_anchor : in_data_anchors) { | |||
if (in_data_anchor == nullptr) { | |||
continue; | |||
} | |||
auto in_node_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if (in_node_anchor == nullptr) { | |||
continue; | |||
} | |||
auto in_node = in_node_anchor->GetOwnerNode(); | |||
if (in_node->GetType() == SWITCHN) { | |||
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | |||
auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx()); | |||
auto identity = | |||
AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph); | |||
GE_CHECK_NOTNULL(identity); | |||
GELOGI("Create new identity node[%s] success.", identity->GetName().c_str()); | |||
GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0))) | |||
GE_CHECK_NOTNULL(identity->GetOutControlAnchor()); | |||
if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) { | |||
continue; | |||
} | |||
GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor())) | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor, | |||
ComputeGraphPtr &graph) { | |||
if (graph == nullptr) { | |||
GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node."); | |||
return nullptr; | |||
} | |||
OpDescPtr desc = MakeShared<OpDesc>("", ""); | |||
if (desc == nullptr) { | |||
GELOGE(MEMALLOC_FAILED, "Failed to create op desc."); | |||
return nullptr; | |||
} | |||
desc->SetName(name); | |||
desc->SetType(IDENTITY); | |||
auto ret = desc->AddInputDesc(tensor); | |||
auto ret2 = desc->AddOutputDesc(tensor); | |||
if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { | |||
GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity."); | |||
return nullptr; | |||
} | |||
return graph->AddNodeFront(desc); | |||
} | |||
} // namespace ge |
@@ -34,10 +34,6 @@ namespace ge { | |||
class DimensionAdjustPass : public BaseNodePass { | |||
public: | |||
Status Run(ge::NodePtr &node) override; | |||
private: | |||
Status DealWithInNodes(ge::NodePtr &node); | |||
NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph); | |||
}; | |||
} // namespace ge | |||
@@ -23,7 +23,6 @@ | |||
namespace { | |||
const size_t kOutNodesNum = 1; | |||
const size_t kInCtrlNodesNum = 1; | |||
} | |||
namespace ge { | |||
@@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) { | |||
if (out_ctrl_node == nullptr) { | |||
continue; | |||
} | |||
GELOGD("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str()); | |||
if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), | |||
out_ctrl_node->GetName().c_str()); | |||
@@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) { | |||
} | |||
} | |||
} else { | |||
if (OptimizeEnterWithOnlyOutData(node, in_node) != SUCCESS) { | |||
GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) { | |||
GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str()); | |||
if (OptimizeEnter(node, in_node) != SUCCESS) { | |||
GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
@@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) { | |||
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | |||
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | |||
return SUCCESS; | |||
} | |||
@@ -89,45 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) | |||
} | |||
GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | |||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))) | |||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | |||
const auto &out_data_anchor = node->GetOutDataAnchor(0); | |||
GE_CHECK_NOTNULL(out_data_anchor); | |||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)) | |||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)) | |||
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | |||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | |||
} | |||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)) | |||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||
AddNodeDeleted(node); | |||
AddRePassNodesWithInOut(in_node); | |||
return SUCCESS; | |||
} | |||
Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) { | |||
auto out_ctrl_nodes = node->GetOutControlNodes(); | |||
if (out_ctrl_nodes.empty()) { | |||
return SUCCESS; | |||
} | |||
auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||
GE_CHECK_NOTNULL(out_ctrl_anchor); | |||
for (auto &out_ctrl_node : out_ctrl_nodes) { | |||
GE_CHECK_NOTNULL(out_ctrl_node); | |||
if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) { | |||
continue; | |||
} | |||
auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes(); | |||
if (in_ctrl_nodes.size() != kInCtrlNodesNum) { | |||
continue; | |||
} | |||
GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor())) | |||
auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes(); | |||
for (auto &out_node_of_const : out_nodes_of_const) { | |||
if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) { | |||
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor())) | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass { | |||
Status Run(NodePtr &node) override; | |||
private: | |||
Status OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node); | |||
Status UnlinkCtrlEdgeBeforeConst(NodePtr &node); | |||
Status OptimizeEnter(NodePtr &node, NodePtr &in_node); | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_ENTER_PASS_H_ |
@@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { | |||
continue; | |||
} | |||
auto in_node = in_node_anchor->GetOwnerNode(); | |||
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) { | |||
if (in_node == nullptr) { | |||
continue; | |||
} | |||
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { | |||
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | |||
auto ret = in_node_anchor->Unlink(in_data_anchor); | |||
if (ret != SUCCESS) { | |||
@@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co | |||
GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); | |||
} | |||
if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { | |||
string batch_label; | |||
(void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||
if (!batch_label.empty()) { | |||
auto stream_merge_desc = stream_merge->GetOpDesc(); | |||
GE_CHECK_NOTNULL(stream_merge_desc); | |||
(void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||
} | |||
} | |||
return AddActiveNodes(graph, stream_merge); | |||
} | |||
@@ -19,8 +19,6 @@ | |||
#include "common/ge/ge_util.h" | |||
#include "graph/common/omg_util.h" | |||
using std::string; | |||
namespace ge { | |||
Status NextIterationPass::Run(ComputeGraphPtr graph) { | |||
GELOGD("NextIterationPass Enter"); | |||
@@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
if (GroupWithNoBatch(graph) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr."); | |||
return INTERNAL_ERROR; | |||
} | |||
if (FindWhileGroups() != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Find while groups failed."); | |||
@@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { | |||
return FAILED; | |||
} | |||
string batch_label; | |||
if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||
frame_name += batch_label; | |||
std::string batch_label; | |||
(void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label); | |||
if (batch_label.empty()) { | |||
auto frame_iter = frame_enter_map_.find(frame_name); | |||
if (frame_iter == frame_enter_map_.end()) { | |||
std::vector<NodePtr> enter_nodes; | |||
enter_nodes.emplace_back(enter_node); | |||
frame_enter_map_[frame_name] = enter_nodes; | |||
} else { | |||
frame_iter->second.emplace_back(enter_node); | |||
} | |||
return SUCCESS; | |||
} | |||
auto iter = loop_group_map_.find(frame_name); | |||
if (iter == loop_group_map_.end()) { | |||
auto group_iter = loop_group_map_.find(frame_name); | |||
if (group_iter == loop_group_map_.end()) { | |||
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||
if (loop_group == nullptr) { | |||
GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||
return FAILED; | |||
} | |||
loop_group->enter_nodes.emplace_back(enter_node); | |||
loop_group_map_[frame_name] = loop_group; | |||
loop_group_map_[frame_name][batch_label] = loop_group; | |||
} else { | |||
iter->second->enter_nodes.emplace_back(enter_node); | |||
auto batch_iter = group_iter->second.find(batch_label); | |||
if (batch_iter == group_iter->second.end()) { | |||
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||
if (loop_group == nullptr) { | |||
GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||
return FAILED; | |||
} | |||
loop_group->enter_nodes.emplace_back(enter_node); | |||
group_iter->second[batch_label] = loop_group; | |||
} else { | |||
batch_iter->second->enter_nodes.emplace_back(enter_node); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
/// | |||
/// @brief Group Enter nodes without batch_label attr | |||
/// @param [in] compute_graph | |||
/// @return Status | |||
/// | |||
Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) { | |||
if (frame_enter_map_.empty()) { | |||
GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
for (const auto &item : frame_enter_map_) { | |||
const std::string &frame_name = item.first; | |||
auto iter = loop_group_map_.find(frame_name); | |||
if (iter == loop_group_map_.end()) { | |||
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>(); | |||
if (loop_group == nullptr) { | |||
GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); | |||
return FAILED; | |||
} | |||
loop_group->enter_nodes = item.second; | |||
loop_group_map_[frame_name][""] = loop_group; | |||
} else { | |||
for (auto &batch_item : iter->second) { | |||
for (const auto &enter_node : item.second) { | |||
batch_item.second->enter_nodes.emplace_back(enter_node); | |||
} | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { | |||
Status NextIterationPass::FindWhileGroups() { | |||
for (const auto &loop_group_iter : loop_group_map_) { | |||
const std::string &frame_name = loop_group_iter.first; | |||
for (const auto &enter_node : loop_group_iter.second->enter_nodes) { | |||
for (const auto &out_node : enter_node->GetOutAllNodes()) { | |||
const string &type = out_node->GetType(); | |||
if ((type != MERGE) && (type != REFMERGE)) { | |||
continue; | |||
} | |||
NodePtr next_node = nullptr; | |||
if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | |||
NodePtr switch_node = nullptr; | |||
if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
if (switch_node == nullptr) { | |||
continue; | |||
} | |||
NodePtr loop_cond = nullptr; | |||
if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
if (loop_group_iter.second->loop_cond == nullptr) { | |||
loop_group_iter.second->loop_cond = loop_cond; | |||
} else if (loop_group_iter.second->loop_cond != loop_cond) { | |||
GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); | |||
return FAILED; | |||
for (const auto &batch_iter : loop_group_iter.second) { | |||
const std::string &batch_label = batch_iter.first; | |||
for (const auto &enter_node : batch_iter.second->enter_nodes) { | |||
for (const auto &out_node : enter_node->GetOutAllNodes()) { | |||
GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(), | |||
frame_name.c_str(), batch_label.c_str()); | |||
if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) { | |||
continue; | |||
} | |||
std::string tmp_label; | |||
GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); | |||
if (need_skip) { | |||
continue; | |||
} | |||
NodePtr next_node = nullptr; | |||
if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, | |||
"Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", | |||
out_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | |||
NodePtr switch_node = nullptr; | |||
if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", | |||
out_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
if (switch_node == nullptr) { | |||
continue; | |||
} | |||
NodePtr loop_cond = nullptr; | |||
if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, | |||
"Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", | |||
switch_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
if (batch_iter.second->loop_cond == nullptr) { | |||
batch_iter.second->loop_cond = loop_cond; | |||
} else if (batch_iter.second->loop_cond != loop_cond) { | |||
GELOGE(FAILED, "Multi LoopCond nodes exist."); | |||
return FAILED; | |||
} | |||
} | |||
} | |||
} | |||
@@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() { | |||
GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); | |||
return false; | |||
} | |||
if (loop_group_iter.second->loop_cond == nullptr) { | |||
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); | |||
return false; | |||
} | |||
for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { | |||
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { | |||
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", | |||
frame_name.c_str()); | |||
for (const auto &batch_iter : loop_group_iter.second) { | |||
if (batch_iter.second->loop_cond == nullptr) { | |||
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); | |||
return false; | |||
} | |||
for (const auto &pair_iter : batch_iter.second->merge_next_pairs) { | |||
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { | |||
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", | |||
frame_name.c_str()); | |||
return false; | |||
} | |||
} | |||
} | |||
} | |||
@@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() { | |||
/// | |||
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
for (const auto &loop_cond_iter : loop_group_map_) { | |||
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | |||
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||
// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); | |||
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); | |||
if ((enter_active == nullptr) || (next_active == nullptr)) { | |||
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { | |||
// Enter --> Active | |||
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(), | |||
enter_active->GetName().c_str()); | |||
for (const auto &batch_iter : loop_cond_iter.second) { | |||
const std::string &cond_name = batch_iter.second->loop_cond->GetName(); | |||
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | |||
// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | |||
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); | |||
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); | |||
if ((enter_active == nullptr) || (next_active == nullptr)) { | |||
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | |||
NodePtr merge_node = pair.first; | |||
NodePtr next_node = pair.second; | |||
// Active --> Merge | |||
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||
return INTERNAL_ERROR; | |||
for (const auto &enter_node : batch_iter.second->enter_nodes) { | |||
// Enter --> Active | |||
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != | |||
GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
// NextIteration --> Active | |||
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||
return INTERNAL_ERROR; | |||
for (const auto &pair : batch_iter.second->merge_next_pairs) { | |||
NodePtr merge_node = pair.first; | |||
NodePtr next_node = pair.second; | |||
// Active --> Merge | |||
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != | |||
GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
// NextIteration --> Active | |||
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Add control edge failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
// break link between NextIteration and Merge | |||
if (BreakNextIteration(next_node, merge_node) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
// break link between NextIteration and Merge | |||
if (BreakNextIteration(next_node, merge_node) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); | |||
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { | |||
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | |||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { | |||
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); | |||
return INTERNAL_ERROR; | |||
} | |||
} | |||
return SUCCESS; | |||
@@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & | |||
/// @param [in] node | |||
/// @param [in] target_type | |||
/// @param [in] is_input | |||
/// @param [in] batch_label | |||
/// @param [out] target_node | |||
/// @return Status | |||
/// | |||
Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | |||
NodePtr &target_node) { | |||
const std::string &batch_label, NodePtr &target_node) { | |||
if (node == nullptr) { | |||
GELOGE(PARAM_INVALID, "node is null."); | |||
return PARAM_INVALID; | |||
@@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string | |||
} | |||
for (const auto &tmp_node : nodes) { | |||
std::string tmp_label; | |||
(void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label); | |||
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label)); | |||
if (need_skip) { | |||
continue; | |||
} | |||
const std::string type = tmp_node->GetType(); | |||
if ((target_type == LOOPCOND) && (type == target_type)) { | |||
target_node = tmp_node; | |||
@@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string | |||
/// @return SUCCESS | |||
/// | |||
Status NextIterationPass::ClearStatus() { | |||
frame_enter_map_.clear(); | |||
loop_group_map_.clear(); | |||
return SUCCESS; | |||
} | |||
@@ -46,6 +46,13 @@ class NextIterationPass : public GraphPass { | |||
/// | |||
Status GroupEnterNode(const NodePtr &enter_node); | |||
/// | |||
/// @brief Group Enter nodes without batch_label attr | |||
/// @param [in] compute_graph | |||
/// @return Status | |||
/// | |||
Status GroupWithNoBatch(const ComputeGraphPtr &graph); | |||
/// | |||
/// @brief Find while groups | |||
/// @return Status | |||
@@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass { | |||
/// @param [out] target_node | |||
/// @return Status | |||
/// | |||
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | |||
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, | |||
const std::string &batch_label, NodePtr &target_node); | |||
// map<frame_name, LoopCondGroup> | |||
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | |||
// map<frame_name, vector<enter_node>> | |||
std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_; | |||
// map<frame_name, map<batch_label, LoopCondGroup>> | |||
std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ |
@@ -44,8 +44,6 @@ | |||
using std::set; | |||
using std::string; | |||
using std::vector; | |||
using std::map; | |||
using std::queue; | |||
namespace ge { | |||
namespace multibatch { | |||
@@ -59,15 +57,10 @@ const int kDataInIndex = 0; | |||
const int kMergeDataOutIndex = 0; | |||
const int kStaticOutput = -1; | |||
const int kDivisionConst = 2; | |||
const int32_t kOneInDataNode = 1; | |||
const int32_t kFindNoMatch = 0; | |||
inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } | |||
inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); } | |||
const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER}); | |||
inline bool IsGetNextType(const NodePtr &node) { | |||
std::string original_type; | |||
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, | |||
@@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() { | |||
return ret; | |||
} | |||
ret = InsertIdentityAfterSwitchN(); | |||
if (ret != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node."); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGI("Begin to remove useless nodes by prune pass after copy process"); | |||
PrunePass prune_pass; | |||
ret = prune_pass.Run(graph_); | |||
@@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() { | |||
return ret; | |||
} | |||
ret = RelinkConstCtrlEdge(); | |||
if (ret != SUCCESS) { | |||
GELOGE(FAILED, "Relink const's control edge failed."); | |||
return FAILED; | |||
} | |||
ret = ExtractUnchangedStructureOutofCycle(); | |||
if (ret != SUCCESS) { | |||
GELOGE(FAILED, "Extract unchanged structure out of cycle failed."); | |||
return FAILED; | |||
} | |||
for (auto &node : graph_->GetAllNodes()) { | |||
origin_all_nodes_.emplace_back(node); | |||
if (IsDataLikeType(node->GetType())) { | |||
@@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() { | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() { | |||
for (auto &node : graph_->GetAllNodes()) { | |||
GE_CHECK_NOTNULL(node); | |||
if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { | |||
if (node->GetOutDataNodes().empty()) { | |||
continue; | |||
} | |||
if (!node->GetInControlNodes().empty()) { | |||
auto in_ctrl_nodes = node->GetInControlNodes(); | |||
auto out_nodes = node->GetOutAllNodes(); | |||
bool has_merge = false; | |||
for (const auto &out_node : out_nodes) { | |||
GE_CHECK_NOTNULL(out_node); | |||
if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) { | |||
has_merge = true; | |||
break; | |||
} | |||
} | |||
if (has_merge) { | |||
continue; | |||
} | |||
auto in_ctrl_anchor = node->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
in_ctrl_anchor->UnlinkAll(); | |||
for (auto &in_ctrl_node : in_ctrl_nodes) { | |||
auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor(); | |||
GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node); | |||
for (auto &out_node : out_nodes) { | |||
if (IsEnterType(out_node->GetType())) { | |||
continue; | |||
} | |||
if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) { | |||
GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor())) | |||
} | |||
} | |||
} | |||
} | |||
auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||
if (out_ctrl_anchor != nullptr) { | |||
out_ctrl_anchor->UnlinkAll(); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() { | |||
map<string, vector<NodePtr>> frame_enter; | |||
if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) { | |||
GELOGE(FAILED, "Get enter nodes grouped by frame_name failed."); | |||
return FAILED; | |||
} | |||
queue<NodePtr> nodes_to_extract; | |||
if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) { | |||
GELOGE(FAILED, "Get nodes needed to extract failed."); | |||
return FAILED; | |||
} | |||
while (!nodes_to_extract.empty()) { | |||
auto node = nodes_to_extract.front(); | |||
nodes_to_extract.pop(); | |||
OpDescPtr enter_desc = nullptr; | |||
if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) { | |||
GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
set<NodePtr> out_nodes; | |||
if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) { | |||
GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) { | |||
GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
for (auto &out_node : out_nodes) { | |||
GE_CHECK_NOTNULL(out_node); | |||
if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) { | |||
nodes_to_extract.push(out_node); | |||
} | |||
} | |||
} | |||
if (DeleteEnterWithoutDataOut() != SUCCESS) { | |||
GELOGE(FAILED, "Delete enter node without out data nodes failed."); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) { | |||
for (auto &node : graph_->GetAllNodes()) { | |||
GE_CHECK_NOTNULL(node); | |||
if (IsEnterType(node->GetType())) { | |||
if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) { | |||
continue; | |||
} | |||
auto op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
string frame_name; | |||
if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { | |||
GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
frame_enter[frame_name].emplace_back(node); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter, | |||
queue<NodePtr> &nodes_to_extract) { | |||
for (const auto &one_group : frame_enter) { | |||
auto enters = one_group.second; | |||
for (const auto &enter : enters) { | |||
auto out_data_nodes = enter->GetOutDataNodes(); | |||
for (const auto &out_data_node : out_data_nodes) { | |||
GE_CHECK_NOTNULL(out_data_node); | |||
if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) { | |||
nodes_to_extract.push(out_data_node); | |||
} | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) { | |||
auto out_data_nodes = node->GetOutDataNodes(); | |||
for (const auto &out_data_node : out_data_nodes) { | |||
if (out_data_node == nullptr) { | |||
return false; | |||
} | |||
if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) { | |||
return false; | |||
} | |||
} | |||
auto in_data_nodes = node->GetInDataNodes(); | |||
if (in_data_nodes.size() == kOneInDataNode) { | |||
return true; | |||
} | |||
for (const auto &in_data_node : in_data_nodes) { | |||
if (in_data_node == nullptr) { | |||
return false; | |||
} | |||
if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) { | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) { | |||
auto in_data_anchors = node->GetAllInDataAnchors(); | |||
for (auto &in_data_anchor : in_data_anchors) { | |||
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(peer_out_data_anchor); | |||
auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode(); | |||
if (IsEnterType(peer_in_data_node->GetType())) { | |||
GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor)) | |||
GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str()); | |||
auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors(); | |||
for (auto &enter_in_data_anchor : enter_in_data_anchors) { | |||
auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter); | |||
if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) { | |||
continue; | |||
} | |||
GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor)) | |||
GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(), | |||
node->GetName().c_str()); | |||
} | |||
enter_desc = peer_in_data_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(enter_desc); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr ©_desc, set<NodePtr> &out_nodes) { | |||
if (copy_desc == nullptr) { | |||
return SUCCESS; | |||
} | |||
map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes; | |||
auto out_data_anchors = node->GetAllOutDataAnchors(); | |||
for (auto &out_data_anchor : out_data_anchors) { | |||
auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors(); | |||
for (auto peer_in_data_anchor : peer_in_data_anchors) { | |||
GE_CHECK_NOTNULL(peer_in_data_anchor); | |||
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||
out_nodes.emplace(peer_in_data_node); | |||
outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node)); | |||
} | |||
} | |||
int32_t i = 0; | |||
auto node_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(node_desc); | |||
// Insert one enter node after node's per out data anchor | |||
for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) { | |||
string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++); | |||
GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str()); | |||
auto enter_desc = AttrUtils::CopyOpDesc(copy_desc); | |||
enter_desc->SetName(name); | |||
GE_CHK_STATUS_RET( | |||
enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) | |||
GE_CHK_STATUS_RET( | |||
enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx()))) | |||
auto enter_node = graph_->AddNode(enter_desc); | |||
GE_CHECK_NOTNULL(enter_node); | |||
GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex))) | |||
GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex)); | |||
for (auto &inanchor_node : outanchor_inanchors_nodes.second) { | |||
GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first)) | |||
GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first)) | |||
GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(), | |||
inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(), | |||
inanchor_node.second->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
// Move node's in control edges to out data nodes | |||
Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) { | |||
auto in_ctrl_anchor = node->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors(); | |||
for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) { | |||
GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor)) | |||
GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), | |||
node->GetName().c_str()); | |||
for (auto &out_node : out_nodes) { | |||
auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node); | |||
if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) { | |||
GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node)) | |||
GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), | |||
out_node->GetName().c_str()); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() { | |||
for (auto &node : graph_->GetAllNodes()) { | |||
GE_CHECK_NOTNULL(node); | |||
if (IsEnterType(node->GetType())) { | |||
auto out_nodes = node->GetOutAllNodes(); | |||
if (out_nodes.empty()) { | |||
GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str()); | |||
GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {})) | |||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node)) | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) { | |||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||
GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(), | |||
@@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||
LabelStatusForGetNextSink(data); | |||
} | |||
} | |||
map<string, vector<NodePtr>> frame_enters; | |||
InitStatus(frame_enters); | |||
bool changed = true; | |||
// If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch | |||
while (changed) { | |||
@@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||
if (iter != origin_nodes_status_.end()) { | |||
continue; | |||
} | |||
for (auto &in_node : node->GetInDataNodes()) { | |||
if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) { | |||
if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) { | |||
origin_nodes_status_[node.get()] == kNodeInBatchBranch; | |||
ResetEnterStatus(frame_enters, node); | |||
changed = true; | |||
} | |||
for (auto &in_node : node->GetInAllNodes()) { | |||
bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() && | |||
origin_nodes_status_[in_node.get()] == kNodeInBatchBranch; | |||
if (is_in_batch) { | |||
origin_nodes_status_[node.get()] = kNodeInBatchBranch; | |||
changed = true; | |||
break; | |||
} | |||
} | |||
@@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() { | |||
return SUCCESS; | |||
} | |||
void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) { | |||
for (const auto &node : origin_all_nodes_) { | |||
if (!IsEnterType(node->GetType())) { | |||
continue; | |||
} | |||
auto op_desc = node->GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
continue; | |||
} | |||
string frame_name; | |||
if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { | |||
frame_enters[frame_name].emplace_back(node); | |||
} | |||
} | |||
for (const auto &data : origin_data_nodes_) { | |||
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); | |||
if (!IsAllDimsPositive(data_shape.GetDims())) { | |||
origin_nodes_status_[data.get()] = kNodeInBatchBranch; | |||
} | |||
} | |||
} | |||
void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) { | |||
if (!IsEnterType(node->GetType())) { | |||
return; | |||
} | |||
for (const auto &frame_enter : frame_enters) { | |||
auto &enters = frame_enter.second; | |||
if (std::find(enters.begin(), enters.end(), node) != enters.end()) { | |||
for (const auto &enter : enters) { | |||
origin_nodes_status_[enter.get()] = kNodeInBatchBranch; | |||
} | |||
break; | |||
} | |||
} | |||
} | |||
Status MultiBatchGraphCopyer::LabelStatus() { | |||
if (LabelInBatchBranchStatus() != SUCCESS) { | |||
GELOGE(PARAM_INVALID, "Failed to label no in batch branch"); | |||
@@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { | |||
for (auto &node : graph_->GetAllNodes()) { | |||
if (node->GetType() != SWITCHN) { | |||
continue; | |||
} | |||
auto switchn_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(switchn_desc); | |||
size_t i = 0; | |||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
auto out_node = in_data_anchor->GetOwnerNode(); | |||
auto op_desc = out_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { | |||
GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str()); | |||
continue; | |||
} | |||
auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY); | |||
GE_CHECK_NOTNULL(identity_desc); | |||
string batch_label; | |||
if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||
if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { | |||
GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
auto data_desc = switchn_desc->GetOutputDesc(i); | |||
i++; | |||
GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc)); | |||
GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc)); | |||
auto identity_node = graph_->AddNode(identity_desc); | |||
GE_CHECK_NOTNULL(identity_node); | |||
GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0))); | |||
GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor()); | |||
GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor())); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status ProcessMultiBatch(ComputeGraphPtr &graph) { | |||
const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); | |||
if (multi_batch_with_case != nullptr) { | |||
@@ -18,7 +18,6 @@ | |||
#include <map> | |||
#include <queue> | |||
#include <vector> | |||
#include <set> | |||
#include "external/ge/ge_api_error_codes.h" | |||
@@ -65,26 +64,12 @@ class MultiBatchGraphCopyer { | |||
private: | |||
Status Init(); | |||
Status CheckArguments(); | |||
Status RelinkConstCtrlEdge(); | |||
Status ExtractUnchangedStructureOutofCycle(); | |||
Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter); | |||
Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter, | |||
std::queue<NodePtr> &nodes_to_extract); | |||
bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node); | |||
Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc); | |||
Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes); | |||
Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes); | |||
Status DeleteEnterWithoutDataOut(); | |||
// label status for origin_all_nodes_ | |||
Status LabelStatus(); | |||
Status LabelInBatchBranchStatus(); | |||
void LabelStatusForData(const NodePtr &data); | |||
void LabelStatusForGetNextSink(const NodePtr &data); | |||
void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters); | |||
void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node); | |||
// add nodes functions | |||
Status CreateNewNodes(); | |||
@@ -96,6 +81,7 @@ class MultiBatchGraphCopyer { | |||
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index, | |||
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn); | |||
Status InsertIdentityAfterSwitchN(); | |||
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index); | |||
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index); | |||