From: @zhou_lili Reviewed-by: Signed-off-by:tags/v1.2.0
@@ -202,7 +202,9 @@ set(TRAIN_SRC_LIST | |||
"graph/passes/compile_nodes_pass.cc" | |||
"graph/passes/constant_folding_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
"graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||
"graph/passes/remove_same_const_pass.cc" | |||
"graph/passes/no_data_out_const_elimination_pass.cc" | |||
"graph/passes/useless_control_out_remove_pass.cc" | |||
"graph/passes/control_trigger_pass.cc" | |||
"graph/passes/dimension_adjust_pass.cc" | |||
@@ -484,6 +486,7 @@ set(INFER_SRC_LIST | |||
"graph/passes/net_output_pass.cc" | |||
"graph/passes/replace_transshape_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
"graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||
"graph/passes/print_op_pass.cc" | |||
"graph/passes/no_use_reshape_remove_pass.cc" | |||
"graph/passes/iterator_op_pass.cc" | |||
@@ -580,6 +583,7 @@ set(INFER_SRC_LIST | |||
"graph/passes/addn_pass.cc" | |||
"graph/passes/common_subexpression_elimination_pass.cc" | |||
"graph/passes/remove_same_const_pass.cc" | |||
"graph/passes/no_data_out_const_elimination_pass.cc" | |||
"graph/passes/useless_control_out_remove_pass.cc" | |||
"graph/passes/transop_symmetry_elimination_pass.cc" | |||
"graph/passes/save_pass.cc" | |||
@@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \ | |||
graph/passes/net_output_pass.cc \ | |||
graph/passes/replace_transshape_pass.cc \ | |||
graph/passes/constant_fuse_same_pass.cc \ | |||
graph/passes/fuse_data_nodes_with_common_input_pass.cc \ | |||
graph/passes/print_op_pass.cc \ | |||
graph/passes/no_use_reshape_remove_pass.cc \ | |||
graph/passes/iterator_op_pass.cc \ | |||
@@ -193,6 +194,7 @@ OMG_HOST_SRC_FILES := \ | |||
graph/passes/cond_pass.cc \ | |||
graph/passes/cond_remove_pass.cc \ | |||
graph/passes/remove_same_const_pass.cc \ | |||
graph/passes/no_data_out_const_elimination_pass.cc \ | |||
graph/passes/useless_control_out_remove_pass.cc \ | |||
graph/passes/for_pass.cc \ | |||
graph/passes/enter_pass.cc \ | |||
@@ -127,7 +127,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
graph/passes/compile_nodes_pass.cc \ | |||
graph/passes/constant_folding_pass.cc \ | |||
graph/passes/constant_fuse_same_pass.cc \ | |||
graph/passes/fuse_data_nodes_with_common_input_pass.cc \ | |||
graph/passes/remove_same_const_pass.cc \ | |||
graph/passes/no_data_out_const_elimination_pass.cc \ | |||
graph/passes/useless_control_out_remove_pass.cc \ | |||
graph/passes/control_trigger_pass.cc \ | |||
graph/passes/dimension_adjust_pass.cc \ | |||
@@ -65,7 +65,7 @@ class ZeroCopyOffset { | |||
// data_size of Data/Netoutput | |||
int64_t GetDataSize() const { return data_size_; } | |||
// value of *outside_addrs_ from davinci_model | |||
std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | |||
const std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | |||
// name of op | |||
std::string GetOpName() const { return op_name_; } | |||
@@ -53,6 +53,7 @@ | |||
#include "graph/passes/dimension_adjust_pass.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/flow_ctrl_pass.h" | |||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
#include "graph/passes/identity_pass.h" | |||
#include "graph/passes/input_output_connection_identify_pass.h" | |||
#include "graph/passes/iterator_op_pass.h" | |||
@@ -70,6 +71,7 @@ | |||
#include "graph/passes/remove_same_const_pass.h" | |||
#include "graph/passes/reshape_recovery_pass.h" | |||
#include "graph/passes/reshape_remove_pass.h" | |||
#include "graph/passes/no_data_out_const_elimination_pass.h" | |||
#include "graph/passes/same_transdata_breadth_fusion_pass.h" | |||
#include "graph/passes/subgraph_pass.h" | |||
#include "graph/passes/switch_data_edges_bypass.h" | |||
@@ -2104,6 +2106,24 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); | |||
GE_CHK_STATUS_RET( | |||
after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); | |||
/* | |||
* Do CSE before FuseDataNodesWithCommonInputPass to resolve the scene in bertlarge as following: | |||
* const | |||
* / | \ | |||
* cast1 cast2 cast3 | |||
* \ | / | |||
* case | |||
* The node `const` is the fused const node produced by ConstantFuseSamePass. | |||
* The nodes `cast1`, `cast2` and `cast3` will be fused by CSE. | |||
* To avoid hard-coded handling of this case in FuseDataNodesWithCommonInputPass, | |||
* CSE is run before FuseDataNodesWithCommonInputPass. | |||
* This is a temporary solution: this CSE will be removed once the pass is changed from a graph pass to a node pass. | |||
*/ | |||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CSEBeforeFuseDataNodesWithCommonInputPass", | |||
new (std::nothrow) CommonSubexpressionEliminationPass)); | |||
// FuseDataNodesWithCommonInputPass: fuse same data with common input in same graph | |||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::FuseDataNodesWithCommonInputPass", | |||
new (std::nothrow) FuseDataNodesWithCommonInputPass)); | |||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass", | |||
new (std::nothrow) CommonSubexpressionEliminationPass)); | |||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass)) | |||
@@ -2226,12 +2246,14 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||
GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); | |||
return ret; | |||
} | |||
NamesToPass identity_remove_pass; | |||
GE_TIMESTAMP_START(identity_remove_pass); | |||
NamesToPass node_pass; | |||
GE_TIMESTAMP_START(node_pass); | |||
IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass | |||
identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); | |||
ret = GEPass(compute_graph).Run(identity_remove_pass); | |||
GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); | |||
NoDataOutConstEliminationPass no_data_out_const_elimination_pass; | |||
node_pass.emplace_back("IdentityPass", &identity_force_pass); | |||
node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); | |||
ret = GEPass(compute_graph).Run(node_pass); | |||
GE_TIMESTAMP_END(node_pass, "GraphPrepare::node_pass"); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); | |||
return ret; | |||
@@ -0,0 +1,119 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
#include <map> | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include <set> | |||
#include "common/ge_inner_error_codes.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
using std::map; | |||
using std::vector; | |||
using std::set; | |||
using std::string; | |||
namespace ge { | |||
/// Entry point of the pass.
///
/// Rejects a null graph, gathers the per-subgraph fusion information via
/// InitNeedFuseNodesInfo() and then hands it to FuseDataNodes().
///
/// @param graph  graph whose subgraphs are inspected
/// @return GE_GRAPH_PARAM_NULLPTR when `graph` is null, FAILED when the
///         information gathering step fails, otherwise whatever
///         FuseDataNodes() returns
Status FuseDataNodesWithCommonInputPass::Run(ge::ComputeGraphPtr graph) {
  // peer out anchor of the parent node -> parent input indexes fed by that anchor
  using ParentIndexesByPeerAnchor = map<OutDataAnchorPtr, set<uint32_t>>;

  if (graph == nullptr) {
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null.");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  GELOGD("FuseDataNodesWithCommonInputPass in.");

  // subgraph -> (peer out anchor -> parent input indexes)
  map<ComputeGraphPtr, ParentIndexesByPeerAnchor> fuse_info_by_subgraph;
  const Status init_status = InitNeedFuseNodesInfo(graph, fuse_info_by_subgraph);
  if (init_status != SUCCESS) {
    GELOGE(FAILED, "InitNeedFuseNodesInfo failed.");
    return FAILED;
  }

  return FuseDataNodes(fuse_info_by_subgraph);
}
/// Collects, for every subgraph whose parent node is a CASE or IF node, which
/// parent input indexes are fed by the same peer out anchor.
///
/// @param graph  graph whose subgraphs (graph->GetAllSubgraphs()) are scanned
/// @param subgraphs_to_need_fuse_nodes_info  output; key: subgraph,
///        value: (peer out anchor of the parent node -> set of parent input indexes)
/// @return SUCCESS on completion; the GE_CHECK_NOTNULL guards return early when a
///         subgraph, its parent node, an input anchor or its peer out anchor is null
Status FuseDataNodesWithCommonInputPass::InitNeedFuseNodesInfo(
    ComputeGraphPtr &graph,
    map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) {
  for (const auto &subgraph : graph->GetAllSubgraphs()) {
    GE_CHECK_NOTNULL(subgraph);
    auto parent_node = subgraph->GetParentNode();
    GE_CHECK_NOTNULL(parent_node);
    if (parent_node->GetType() == CASE || parent_node->GetType() == IF) {
      auto &peer_out_anchors_to_parent_indexes = subgraphs_to_need_fuse_nodes_info[subgraph];
      for (const auto &in_data_anchor : parent_node->GetAllInDataAnchors()) {
        GE_CHECK_NOTNULL(in_data_anchor);
        OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
        GE_CHECK_NOTNULL(peer_out_anchor);
        const uint32_t parent_index = static_cast<uint32_t>(in_data_anchor->GetIdx());
        peer_out_anchors_to_parent_indexes[peer_out_anchor].insert(parent_index);
        // parent_index is unsigned, so it is printed with %u.
        GELOGD("Peer node %s is the %u input of parent node %s in %s.",
               peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_index, parent_node->GetName().c_str(),
               subgraph->GetName().c_str());
      }
    }
  }
  return SUCCESS;
}
/// Fuses redundant DATA nodes inside each recorded subgraph.
///
/// For every (peer out anchor -> parent indexes) entry that holds more than one
/// parent index, the DATA nodes of the subgraph whose ATTR_NAME_PARENT_NODE_INDEX
/// is in that index set are collected. All of them except the first are replaced
/// by the first one and removed from the subgraph.
///
/// @param subgraphs_to_need_fuse_nodes_info  key: subgraph,
///        value: (peer out anchor of the parent node -> set of parent input indexes)
/// @return SUCCESS when all fusions succeed; FAILED as soon as one graph edit
///         fails; the GE_CHECK_NOTNULL guards return early on a null node/op desc
Status FuseDataNodesWithCommonInputPass::FuseDataNodes(
    const map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) {
  for (const auto &subgraph_to_need_fuse_nodes_info : subgraphs_to_need_fuse_nodes_info) {
    auto subgraph = subgraph_to_need_fuse_nodes_info.first;
    for (const auto &peer_out_anchor_to_parent_indexes : subgraph_to_need_fuse_nodes_info.second) {
      const set<uint32_t> &parent_indexes = peer_out_anchor_to_parent_indexes.second;
      // A single parent index means there is nothing to share.
      if (parent_indexes.size() <= 1) {
        continue;
      }

      // Data nodes of this subgraph that map to one of the shared parent indexes.
      vector<NodePtr> need_fuse_data_nodes;
      for (const auto &node : subgraph->GetDirectNode()) {
        GE_CHECK_NOTNULL(node);
        if (node->GetType() != DATA) {
          continue;
        }
        GE_CHECK_NOTNULL(node->GetOpDesc());
        uint32_t parent_index = 0;
        if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index) &&
            parent_indexes.count(parent_index) > 0) {
          need_fuse_data_nodes.emplace_back(node);
        }
      }
      if (need_fuse_data_nodes.size() <= 1) {
        continue;
      }

      // Non-const local copies: the original code also passed non-const lvalues to GraphUtils.
      NodePtr first_node = need_fuse_data_nodes.front();
      for (size_t i = 1; i < need_fuse_data_nodes.size(); ++i) {
        NodePtr node = need_fuse_data_nodes[i];
        GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(),
               first_node->GetName().c_str(), subgraph->GetName().c_str());
        // Original author's note: a data node that can be fused has no input
        // (neither data nor control), so only its outgoing edges are handled here.
        if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) {
          GELOGE(FAILED, "[%s] MoveOutCtrlEdges to %s failed.", node->GetName().c_str(),
                 first_node->GetName().c_str());
          return FAILED;
        }
        // NOTE(review): presumably `{}` / `{0}` mean "no inputs, output 0 maps to output 0" - confirm
        // against GraphUtils::ReplaceNodeDataAnchors.
        if (GraphUtils::ReplaceNodeDataAnchors(first_node, node, {}, {0}) != SUCCESS) {
          GELOGE(FAILED, "[%s] ReplaceNodeDataAnchors by %s failed.", node->GetName().c_str(),
                 first_node->GetName().c_str());
          return FAILED;
        }
        if (GraphUtils::RemoveNodeWithoutRelink(subgraph, node) != SUCCESS) {
          GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str());
          return FAILED;
        }
      }
    }
  }
  return SUCCESS;
}
} // namespace ge |
@@ -0,0 +1,38 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||
#define GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||
#include <set> | |||
#include <map> | |||
#include <vector> | |||
#include "graph/types.h" | |||
#include "inc/graph_pass.h" | |||
namespace ge { | |||
class FuseDataNodesWithCommonInputPass : public GraphPass { | |||
public: | |||
Status Run(ge::ComputeGraphPtr graph) override; | |||
private: | |||
Status InitNeedFuseNodesInfo(ComputeGraphPtr &graph, | |||
map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info); | |||
Status FuseDataNodes( | |||
const map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info); | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ |
@@ -0,0 +1,36 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/no_data_out_const_elimination_pass.h" | |||
namespace ge { | |||
/// Removes a const node that has no input and no data output.
///
/// Only nodes of type CONSTANT or CONSTANTOP are considered; every other node
/// is left untouched.
///
/// @param node  node currently visited by the node-pass driver
/// @return SUCCESS when the node is kept or successfully removed, FAILED when
///         IsolateAndDeleteNode() fails; GE_CHECK_NOTNULL returns early on a
///         null node or op desc
Status NoDataOutConstEliminationPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GELOGD("NoDataOutConstEliminationPass running of %s.", node->GetName().c_str());
  if (node->GetType() != CONSTANT && node->GetType() != CONSTANTOP) {
    return SUCCESS;
  }

  GE_CHECK_NOTNULL(node->GetOpDesc());
  // delete const which has no input and no output of data
  if (node->GetOpDesc()->GetInputsSize() == 0 && node->GetOutDataNodes().size() == 0) {
    GELOGI("Remove const %s.", node->GetName().c_str());
    if (IsolateAndDeleteNode(node, {}) != SUCCESS) {
      GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str());
      return FAILED;
    }
  }
  return SUCCESS;
}
} // namespace ge |
@@ -0,0 +1,31 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_PASSES_NO_DATA_OUT_CONST_ELIMINATION_PASS_H_
#define GE_GRAPH_PASSES_NO_DATA_OUT_CONST_ELIMINATION_PASS_H_

#include "graph/passes/base_pass.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"

namespace ge {
/// Node pass that eliminates const nodes which have no input and no data output.
class NoDataOutConstEliminationPass : public BaseNodePass {
 public:
  /// Visits one node; see the .cc file for the elimination condition and status codes.
  Status Run(ge::NodePtr &node) override;
};
}  // namespace ge

#endif  // GE_GRAPH_PASSES_NO_DATA_OUT_CONST_ELIMINATION_PASS_H_
@@ -178,6 +178,8 @@ set(COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/no_data_out_const_elimination_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | |||
"${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" | |||
@@ -616,6 +618,8 @@ set(PASS_TEST_FILES | |||
"graph/passes/trans_op_depth_fusion_pass_unittest.cc" | |||
"graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" | |||
"graph/passes/constant_folding_pass_unittest.cc" | |||
"graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" | |||
"graph/passes/no_data_out_const_elimination_pass_unittest.cc" | |||
"graph/passes/stop_gradient_pass_unittest.cc" | |||
"graph/passes/prevent_gradient_pass_unittest.cc" | |||
"graph/passes/identity_pass_unittest.cc" | |||
@@ -0,0 +1,156 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||
#include <gtest/gtest.h> | |||
#include <string> | |||
#include <vector> | |||
#include <map> | |||
#include "inc/pass_manager.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "graph_builder_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
namespace ge { | |||
class UtestFuseDataNodesWithCommonInputPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
public: | |||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||
for (auto i = 0; i < in_num; ++i) { | |||
op_desc->AddInputDesc(test_desc); | |||
} | |||
for (auto i = 0; i < out_num; ++i) { | |||
op_desc->AddOutputDesc(test_desc); | |||
} | |||
return graph->AddNode(op_desc); | |||
} | |||
}; | |||
/// graph with subgraph
/// parent graph:
///        const
///        | | |
///         case
///          |
///       netoutput
/// subgraph of `case` (only the data nodes are built by this test):
///   data0  data1  data2
///
/// All three inputs of `case` come from the same out anchor of `const`, so the
/// three data nodes are expected to be fused into one.
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
  PassManager pass_manager;
  EXPECT_EQ(pass_manager.AddPass("FuseDataNodesWithCommonInputPass",
                                 new (std::nothrow) FuseDataNodesWithCommonInputPass),
            SUCCESS);

  ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
  auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
  auto parent_case = MakeNode(parent_graph, 3, 1, "parent_case", "Case");
  auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");

  GeTensorDesc tensor_desc(GeShape({1, 3, 224, 224}), FORMAT_NCHW, DT_FLOAT);
  parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
  parent_case->GetOpDesc()->UpdateInputDesc(2, tensor_desc);
  parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);

  GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
  GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
  GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2));
  GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));

  // The data nodes must live in the subgraph: the pass scans subgraph->GetDirectNode().
  ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
  auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data");
  data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data");
  data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  auto data2 = MakeNode(sub_graph, 1, 1, "data2", "Data");
  data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
  (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
  (void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);

  sub_graph->SetParentNode(parent_case);
  sub_graph->SetParentGraph(parent_graph);
  // Register the subgraph so that parent_graph->GetAllSubgraphs() returns it.
  EXPECT_EQ(parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph), GRAPH_SUCCESS);

  // The pass walks graph->GetAllSubgraphs(), so it has to be run on the parent graph.
  EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS);

  size_t data_node_count = 0;
  for (const auto &node : sub_graph->GetDirectNode()) {
    if (node->GetType() == "Data") {
      ++data_node_count;
    }
  }
  EXPECT_EQ(data_node_count, 1U);
}
/// graph with subgraph
/// parent graph:
///        const
///        /   \
///    cast1   cast2
///        \   /
///         case
///          |
///       netoutput
/// subgraph of `case` (only the data nodes are built by this test):
///   data0  data1
///
/// The two inputs of `case` come from different out anchors (cast1 / cast2),
/// so no data node is expected to be fused.
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
  PassManager pass_manager;
  EXPECT_EQ(pass_manager.AddPass("FuseDataNodesWithCommonInputPass",
                                 new (std::nothrow) FuseDataNodesWithCommonInputPass),
            SUCCESS);

  ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
  auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
  auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast");
  auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast");
  auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case");
  auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");

  GeTensorDesc tensor_desc(GeShape({1, 3, 224, 224}), FORMAT_NCHW, DT_FLOAT);
  parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
  parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);

  GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
  GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
  GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0));
  GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
  GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));

  // The data nodes must live in the subgraph: the pass scans subgraph->GetDirectNode().
  ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
  auto data0 = MakeNode(sub_graph, 1, 1, "data0", "Data");
  data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  auto data1 = MakeNode(sub_graph, 1, 1, "data1", "Data");
  data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  (void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
  (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);

  sub_graph->SetParentNode(parent_case);
  sub_graph->SetParentGraph(parent_graph);
  // Register the subgraph so that parent_graph->GetAllSubgraphs() returns it.
  EXPECT_EQ(parent_graph->AddSubgraph(sub_graph->GetName(), sub_graph), GRAPH_SUCCESS);

  // The pass walks graph->GetAllSubgraphs(), so it has to be run on the parent graph.
  EXPECT_EQ(pass_manager.Run(parent_graph), SUCCESS);

  size_t data_node_count = 0;
  for (const auto &node : sub_graph->GetDirectNode()) {
    if (node->GetType() == "Data") {
      ++data_node_count;
    }
  }
  EXPECT_EQ(data_node_count, 2U);
}
} // namespace ge |
@@ -0,0 +1,75 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/passes/no_data_out_const_elimination_pass.h" | |||
#include <gtest/gtest.h> | |||
#include <string> | |||
#include <vector> | |||
#include <map> | |||
#include "common/ge_inner_error_codes.h" | |||
#include "graph/utils/graph_utils.h" | |||
namespace ge { | |||
class UtestNoDataOutConstEliminationPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
public: | |||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||
for (auto i = 0; i < in_num; ++i) { | |||
op_desc->AddInputDesc(test_desc); | |||
} | |||
for (auto i = 0; i < out_num; ++i) { | |||
op_desc->AddOutputDesc(test_desc); | |||
} | |||
return graph->AddNode(op_desc); | |||
} | |||
}; | |||
/// graph without subgraph
///     const1
///       |(control)
///     const2
///       |
///     output
///
/// const1 has no input desc and no data consumer (only a control edge), so the
/// pass is expected to remove it; const2 feeds `output` and must be kept.
TEST_F(UtestNoDataOutConstEliminationPass, succ_graph1) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  auto const_node1 = MakeNode(graph, 0, 1, "const_node1", "Const");
  auto const_node2 = MakeNode(graph, 1, 1, "const_node2", "Const");
  auto output_node = MakeNode(graph, 1, 0, "output_node", "NetOutput");

  GeTensorDesc tensor_desc(GeShape({1, 3, 224, 224}), FORMAT_NCHW, DT_FLOAT);
  const_node1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  const_node2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
  const_node2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
  output_node->GetOpDesc()->UpdateInputDesc(0, tensor_desc);

  GraphUtils::AddEdge(const_node1->GetOutControlAnchor(), const_node2->GetInControlAnchor());
  GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));

  GEPass pass(graph);
  NamesToPass node_pass;
  NoDataOutConstEliminationPass no_data_out_const_elimination_pass;
  node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass);
  EXPECT_EQ(pass.Run(node_pass), SUCCESS);

  // Verify the effect of the pass, not only its return code.
  bool has_const_node1 = false;
  bool has_const_node2 = false;
  for (const auto &node : graph->GetDirectNode()) {
    if (node->GetName() == "const_node1") {
      has_const_node1 = true;
    } else if (node->GetName() == "const_node2") {
      has_const_node2 = true;
    }
  }
  EXPECT_EQ(has_const_node1, false);
  EXPECT_EQ(has_const_node2, true);
}
} // namespace ge |