@@ -157,7 +157,9 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
"graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
"graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
"graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||||
"graph/passes/remove_same_const_pass.cc" | "graph/passes/remove_same_const_pass.cc" | ||||
"graph/passes/no_data_out_const_elimination_pass.cc" | |||||
"graph/passes/useless_control_out_remove_pass.cc" | "graph/passes/useless_control_out_remove_pass.cc" | ||||
"graph/passes/control_trigger_pass.cc" | "graph/passes/control_trigger_pass.cc" | ||||
"graph/passes/dimension_adjust_pass.cc" | "graph/passes/dimension_adjust_pass.cc" | ||||
@@ -439,6 +441,7 @@ set(INFER_SRC_LIST | |||||
"graph/passes/net_output_pass.cc" | "graph/passes/net_output_pass.cc" | ||||
"graph/passes/replace_transshape_pass.cc" | "graph/passes/replace_transshape_pass.cc" | ||||
"graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
"graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||||
"graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
"graph/passes/no_use_reshape_remove_pass.cc" | "graph/passes/no_use_reshape_remove_pass.cc" | ||||
"graph/passes/iterator_op_pass.cc" | "graph/passes/iterator_op_pass.cc" | ||||
@@ -535,6 +538,7 @@ set(INFER_SRC_LIST | |||||
"graph/passes/addn_pass.cc" | "graph/passes/addn_pass.cc" | ||||
"graph/passes/common_subexpression_elimination_pass.cc" | "graph/passes/common_subexpression_elimination_pass.cc" | ||||
"graph/passes/remove_same_const_pass.cc" | "graph/passes/remove_same_const_pass.cc" | ||||
"graph/passes/no_data_out_const_elimination_pass.cc" | |||||
"graph/passes/useless_control_out_remove_pass.cc" | "graph/passes/useless_control_out_remove_pass.cc" | ||||
"graph/passes/transop_symmetry_elimination_pass.cc" | "graph/passes/transop_symmetry_elimination_pass.cc" | ||||
"graph/passes/save_pass.cc" | "graph/passes/save_pass.cc" | ||||
@@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/net_output_pass.cc \ | graph/passes/net_output_pass.cc \ | ||||
graph/passes/replace_transshape_pass.cc \ | graph/passes/replace_transshape_pass.cc \ | ||||
graph/passes/constant_fuse_same_pass.cc \ | graph/passes/constant_fuse_same_pass.cc \ | ||||
graph/passes/fuse_data_nodes_with_common_input_pass.cc \ | |||||
graph/passes/print_op_pass.cc \ | graph/passes/print_op_pass.cc \ | ||||
graph/passes/no_use_reshape_remove_pass.cc \ | graph/passes/no_use_reshape_remove_pass.cc \ | ||||
graph/passes/iterator_op_pass.cc \ | graph/passes/iterator_op_pass.cc \ | ||||
@@ -193,6 +194,7 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/cond_pass.cc \ | graph/passes/cond_pass.cc \ | ||||
graph/passes/cond_remove_pass.cc \ | graph/passes/cond_remove_pass.cc \ | ||||
graph/passes/remove_same_const_pass.cc \ | graph/passes/remove_same_const_pass.cc \ | ||||
graph/passes/no_data_out_const_elimination_pass.cc \ | |||||
graph/passes/useless_control_out_remove_pass.cc \ | graph/passes/useless_control_out_remove_pass.cc \ | ||||
graph/passes/for_pass.cc \ | graph/passes/for_pass.cc \ | ||||
graph/passes/enter_pass.cc \ | graph/passes/enter_pass.cc \ | ||||
@@ -127,7 +127,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/passes/compile_nodes_pass.cc \ | graph/passes/compile_nodes_pass.cc \ | ||||
graph/passes/constant_folding_pass.cc \ | graph/passes/constant_folding_pass.cc \ | ||||
graph/passes/constant_fuse_same_pass.cc \ | graph/passes/constant_fuse_same_pass.cc \ | ||||
graph/passes/fuse_data_nodes_with_common_input_pass.cc \ | |||||
graph/passes/remove_same_const_pass.cc \ | graph/passes/remove_same_const_pass.cc \ | ||||
graph/passes/no_data_out_const_elimination_pass.cc \ | |||||
graph/passes/useless_control_out_remove_pass.cc \ | graph/passes/useless_control_out_remove_pass.cc \ | ||||
graph/passes/control_trigger_pass.cc \ | graph/passes/control_trigger_pass.cc \ | ||||
graph/passes/dimension_adjust_pass.cc \ | graph/passes/dimension_adjust_pass.cc \ | ||||
@@ -65,7 +65,7 @@ class ZeroCopyOffset { | |||||
// data_size of Data/Netoutput | // data_size of Data/Netoutput | ||||
int64_t GetDataSize() const { return data_size_; } | int64_t GetDataSize() const { return data_size_; } | ||||
// value of *outside_addrs_ from davinci_model | // value of *outside_addrs_ from davinci_model | ||||
std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | |||||
const std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; } | |||||
// name of op | // name of op | ||||
std::string GetOpName() const { return op_name_; } | std::string GetOpName() const { return op_name_; } | ||||
@@ -53,6 +53,7 @@ | |||||
#include "graph/passes/dimension_adjust_pass.h" | #include "graph/passes/dimension_adjust_pass.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/flow_ctrl_pass.h" | #include "graph/passes/flow_ctrl_pass.h" | ||||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||||
#include "graph/passes/identity_pass.h" | #include "graph/passes/identity_pass.h" | ||||
#include "graph/passes/input_output_connection_identify_pass.h" | #include "graph/passes/input_output_connection_identify_pass.h" | ||||
#include "graph/passes/iterator_op_pass.h" | #include "graph/passes/iterator_op_pass.h" | ||||
@@ -70,6 +71,7 @@ | |||||
#include "graph/passes/remove_same_const_pass.h" | #include "graph/passes/remove_same_const_pass.h" | ||||
#include "graph/passes/reshape_recovery_pass.h" | #include "graph/passes/reshape_recovery_pass.h" | ||||
#include "graph/passes/reshape_remove_pass.h" | #include "graph/passes/reshape_remove_pass.h" | ||||
#include "graph/passes/no_data_out_const_elimination_pass.h" | |||||
#include "graph/passes/same_transdata_breadth_fusion_pass.h" | #include "graph/passes/same_transdata_breadth_fusion_pass.h" | ||||
#include "graph/passes/subgraph_pass.h" | #include "graph/passes/subgraph_pass.h" | ||||
#include "graph/passes/switch_data_edges_bypass.h" | #include "graph/passes/switch_data_edges_bypass.h" | ||||
@@ -2104,6 +2106,24 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); | after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); | ||||
GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); | after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); | ||||
/* | |||||
* Do CSE before FuseDataNodesWithCommonInputPass to resolve the scene in bertlarge as following: | |||||
* const | |||||
* / | \ | |||||
* cast1 cast2 cast3 | |||||
* \ | / | |||||
* case | |||||
* the node `const` is the fused const node after ConstantFuseSamePass | |||||
* the nodes `cast1`, `cast2` and 'cast3' will be fused by CSE. | |||||
* in order to eliminate hard code in FuseDataNodesWithCommonInputPass, | |||||
* we do CSE before FuseDataNodesWithCommonInputPass | |||||
* But it is a temp solution, this CSE will be deleted after change pass from graph pass to node pass | |||||
*/ | |||||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CSEBeforeFuseDataNodesWithCommonInputPass", | |||||
new (std::nothrow) CommonSubexpressionEliminationPass)); | |||||
// FuseDataNodesWithCommonInputPass: fuse same data with common input in same graph | |||||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::FuseDataNodesWithCommonInputPass", | |||||
new (std::nothrow) FuseDataNodesWithCommonInputPass)); | |||||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass", | GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass", | ||||
new (std::nothrow) CommonSubexpressionEliminationPass)); | new (std::nothrow) CommonSubexpressionEliminationPass)); | ||||
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass)) | GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass)) | ||||
@@ -2226,12 +2246,14 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); | GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
NamesToPass identity_remove_pass; | |||||
GE_TIMESTAMP_START(identity_remove_pass); | |||||
NamesToPass node_pass; | |||||
GE_TIMESTAMP_START(node_pass); | |||||
IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass | IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass | ||||
identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); | |||||
ret = GEPass(compute_graph).Run(identity_remove_pass); | |||||
GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); | |||||
NoDataOutConstEliminationPass no_data_out_const_elimination_pass; | |||||
node_pass.emplace_back("IdentityPass", &identity_force_pass); | |||||
node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); | |||||
ret = GEPass(compute_graph).Run(node_pass); | |||||
GE_TIMESTAMP_END(node_pass, "GraphPrepare::node_pass"); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); | GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); | ||||
return ret; | return ret; | ||||
@@ -0,0 +1,119 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <vector> | |||||
#include <set> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
using std::map; | |||||
using std::vector; | |||||
using std::set; | |||||
using std::string; | |||||
namespace ge { | |||||
Status FuseDataNodesWithCommonInputPass::Run(ge::ComputeGraphPtr graph) { | |||||
if (graph == nullptr) { | |||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); | |||||
return GE_GRAPH_PARAM_NULLPTR; | |||||
} | |||||
GELOGD("FuseDataNodesWithCommonInputPass in."); | |||||
// key: subgraph, value:--key: peer out anchor to parent node, --value: parent indexes to parent node | |||||
map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> subgraphs_to_need_fuse_nodes_info; | |||||
if (InitNeedFuseNodesInfo(graph, subgraphs_to_need_fuse_nodes_info) != SUCCESS) { | |||||
GELOGE(FAILED, "InitNeedFuseNodesInfo failed."); | |||||
return FAILED; | |||||
} | |||||
return FuseDataNodes(subgraphs_to_need_fuse_nodes_info); | |||||
} | |||||
Status FuseDataNodesWithCommonInputPass::InitNeedFuseNodesInfo(ComputeGraphPtr &graph, | |||||
map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) { | |||||
for (const auto &subgraph : graph->GetAllSubgraphs()) { | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
auto parent_node = subgraph->GetParentNode(); | |||||
GE_CHECK_NOTNULL(parent_node); | |||||
if (parent_node->GetType() == CASE || parent_node->GetType() == IF) { | |||||
auto &peer_out_anchors_to_parent_indexes = subgraphs_to_need_fuse_nodes_info[subgraph]; | |||||
for (const auto &in_data_anchor : parent_node->GetAllInDataAnchors()) { | |||||
GE_CHECK_NOTNULL(in_data_anchor); | |||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
uint32_t parent_index = static_cast<uint32_t>(in_data_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(peer_out_anchor); | |||||
peer_out_anchors_to_parent_indexes[peer_out_anchor].insert(parent_index); | |||||
GELOGD("Peer node %s is the %d input of parent node %s in %s.", | |||||
peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_index, parent_node->GetName().c_str(), | |||||
subgraph->GetName().c_str()); | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status FuseDataNodesWithCommonInputPass::FuseDataNodes( | |||||
const map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) { | |||||
for (const auto &subgraph_to_need_fuse_nodes_info : subgraphs_to_need_fuse_nodes_info) { | |||||
auto subgraph = subgraph_to_need_fuse_nodes_info.first; | |||||
for (const auto &peer_out_anchors_to_parent_indexes : subgraph_to_need_fuse_nodes_info.second) { | |||||
if (peer_out_anchors_to_parent_indexes.second.size() <= 1) { | |||||
continue; | |||||
} | |||||
// key: out anchor, value: data nodes with common input will be fused | |||||
map<OutDataAnchorPtr, vector<NodePtr>> peer_out_anchors_to_need_fuse_nodes; | |||||
for (const auto &node : subgraph->GetDirectNode()) { | |||||
if (node->GetType() != DATA) { | |||||
continue; | |||||
} | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
uint32_t parent_index = 0; | |||||
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
if (peer_out_anchors_to_parent_indexes.second.count(parent_index) > 0) { | |||||
peer_out_anchors_to_need_fuse_nodes[peer_out_anchors_to_parent_indexes.first].emplace_back(node); | |||||
} | |||||
} | |||||
} | |||||
for (const auto &peer_out_anchor_to_need_fuse_nodes : peer_out_anchors_to_need_fuse_nodes) { | |||||
auto need_fuse_data_nodes = peer_out_anchor_to_need_fuse_nodes.second; | |||||
auto first_node = need_fuse_data_nodes.at(0); | |||||
for (size_t i = 1; i < need_fuse_data_nodes.size(); ++i) { | |||||
auto node = need_fuse_data_nodes.at(i); | |||||
GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(), | |||||
first_node->GetName().c_str(), subgraph->GetName().c_str()); | |||||
// the data node which can be fused has none input(both data and control in) | |||||
if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
if (GraphUtils::ReplaceNodeDataAnchors(first_node, node, {}, {0}) != SUCCESS) { | |||||
return FAILED; | |||||
} | |||||
if (GraphUtils::RemoveNodeWithoutRelink(subgraph, node) != SUCCESS) { | |||||
GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,38 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||||
#define GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ | |||||
#include <set> | |||||
#include <map> | |||||
#include <vector> | |||||
#include "graph/types.h" | |||||
#include "inc/graph_pass.h" | |||||
namespace ge { | |||||
class FuseDataNodesWithCommonInputPass : public GraphPass { | |||||
public: | |||||
Status Run(ge::ComputeGraphPtr graph) override; | |||||
private: | |||||
Status InitNeedFuseNodesInfo(ComputeGraphPtr &graph, | |||||
map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info); | |||||
Status FuseDataNodes( | |||||
const map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_ |
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/no_data_out_const_elimination_pass.h" | |||||
namespace ge { | |||||
Status NoDataOutConstEliminationPass::Run(NodePtr &node) { | |||||
GE_CHECK_NOTNULL(node); | |||||
GELOGD("RemoveConstWithoutDataPass running of %s.", node->GetName().c_str()); | |||||
if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
// delete const which has no input and no output of data | |||||
if (node->GetOpDesc()->GetInputsSize() == 0 && node->GetOutDataNodes().size() == 0) { | |||||
GELOGI("Remove const %s.", node->GetName().c_str()); | |||||
if (IsolateAndDeleteNode(node, {}) != SUCCESS) { | |||||
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ | |||||
#define GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ | |||||
#include "graph/passes/base_pass.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/util.h" | |||||
namespace ge { | |||||
class NoDataOutConstEliminationPass : public BaseNodePass { | |||||
public: | |||||
Status Run(ge::NodePtr &node) override; | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_ |
@@ -178,6 +178,8 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/no_data_out_const_elimination_pass.cc" | |||||
"${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc" | ||||
@@ -616,6 +618,8 @@ set(PASS_TEST_FILES | |||||
"graph/passes/trans_op_depth_fusion_pass_unittest.cc" | "graph/passes/trans_op_depth_fusion_pass_unittest.cc" | ||||
"graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" | "graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc" | ||||
"graph/passes/constant_folding_pass_unittest.cc" | "graph/passes/constant_folding_pass_unittest.cc" | ||||
"graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc" | |||||
"graph/passes/no_data_out_const_elimination_pass_unittest.cc" | |||||
"graph/passes/stop_gradient_pass_unittest.cc" | "graph/passes/stop_gradient_pass_unittest.cc" | ||||
"graph/passes/prevent_gradient_pass_unittest.cc" | "graph/passes/prevent_gradient_pass_unittest.cc" | ||||
"graph/passes/identity_pass_unittest.cc" | "graph/passes/identity_pass_unittest.cc" | ||||
@@ -0,0 +1,156 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h" | |||||
#include <gtest/gtest.h> | |||||
#include <string> | |||||
#include <vector> | |||||
#include <map> | |||||
#include "inc/pass_manager.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "graph_builder_utils.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
namespace ge { | |||||
class UtestFuseDataNodesWithCommonInputPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
public: | |||||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (auto i = 0; i < in_num; ++i) { | |||||
op_desc->AddInputDesc(test_desc); | |||||
} | |||||
for (auto i = 0; i < out_num; ++i) { | |||||
op_desc->AddOutputDesc(test_desc); | |||||
} | |||||
return graph->AddNode(op_desc); | |||||
} | |||||
}; | |||||
/// graph with subgraph | |||||
/// const | |||||
/// | | | | |||||
/// case | |||||
/// | | |||||
/// netoutput | |||||
/// ... | |||||
/// data0 data1 data2 | |||||
/// | \ / | |||||
/// conv add | |||||
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||||
PassManager pass_manager; | |||||
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); | |||||
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); | |||||
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); | |||||
auto parent_case = MakeNode(parent_graph, 3, 1, "parent_case", "Case"); | |||||
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); | |||||
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateInputDesc(2, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2)); | |||||
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); | |||||
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); | |||||
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); | |||||
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); | |||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data"); | |||||
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); | |||||
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
(void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||||
sub_graph->SetParentNode(parent_case); | |||||
sub_graph->SetParentGraph(parent_graph); | |||||
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); | |||||
} | |||||
/// graph with subgraph | |||||
/// const | |||||
/// / \ | |||||
/// cast1 cast2 | |||||
/// \ / | |||||
/// case | |||||
/// | | |||||
/// netoutput | |||||
/// ... | |||||
/// data1 data2 | |||||
/// \ / | |||||
/// add | |||||
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) { | |||||
PassManager pass_manager; | |||||
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass); | |||||
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph"); | |||||
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const"); | |||||
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast"); | |||||
auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast"); | |||||
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case"); | |||||
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput"); | |||||
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc); | |||||
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0)); | |||||
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); | |||||
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data"); | |||||
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data"); | |||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); | |||||
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
sub_graph->SetParentNode(parent_case); | |||||
sub_graph->SetParentGraph(parent_graph); | |||||
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,75 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/no_data_out_const_elimination_pass.h" | |||||
#include <gtest/gtest.h> | |||||
#include <string> | |||||
#include <vector> | |||||
#include <map> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
namespace ge { | |||||
class UtestNoDataOutConstEliminationPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
public: | |||||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (auto i = 0; i < in_num; ++i) { | |||||
op_desc->AddInputDesc(test_desc); | |||||
} | |||||
for (auto i = 0; i < out_num; ++i) { | |||||
op_desc->AddOutputDesc(test_desc); | |||||
} | |||||
return graph->AddNode(op_desc); | |||||
} | |||||
}; | |||||
/// graph with subgraph | |||||
/// const1 | |||||
/// |(control) | |||||
/// const2 | |||||
/// | | |||||
/// output | |||||
TEST_F(UtestNoDataOutConstEliminationPass, succ_graph1) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
auto const_node1 = MakeNode(graph, 0, 1, "const_node1", "Const"); | |||||
auto const_node2 = MakeNode(graph, 1, 1, "const_node2", "Const"); | |||||
auto output_node = MakeNode(graph, 1, 0, "output_node", "NetOutput"); | |||||
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT); | |||||
const_node1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
const_node2->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
const_node2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc); | |||||
output_node->GetOpDesc()->UpdateInputDesc(0, tensor_desc); | |||||
GraphUtils::AddEdge(const_node1->GetOutControlAnchor(), const_node2->GetInControlAnchor()); | |||||
GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
GEPass pass(graph); | |||||
NamesToPass node_pass; | |||||
NoDataOutConstEliminationPass no_data_out_const_elimination_pass; | |||||
node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass); | |||||
EXPECT_EQ(pass.Run(node_pass), SUCCESS); | |||||
} | |||||
} // namespace ge |