Browse Source

fix infer time and mem when online infer dynamic

tags/v1.2.0
zhou_lili 3 years ago
parent
commit
6f10a03c59
12 changed files with 495 additions and 6 deletions
  1. +4
    -0
      ge/CMakeLists.txt
  2. +2
    -0
      ge/ge_inference.mk
  3. +2
    -0
      ge/ge_runner.mk
  4. +1
    -1
      ge/graph/load/new_model_manager/zero_copy_offset.h
  5. +27
    -5
      ge/graph/manager/graph_manager.cc
  6. +119
    -0
      ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc
  7. +38
    -0
      ge/graph/passes/fuse_data_nodes_with_common_input_pass.h
  8. +36
    -0
      ge/graph/passes/no_data_out_const_elimination_pass.cc
  9. +31
    -0
      ge/graph/passes/no_data_out_const_elimination_pass.h
  10. +4
    -0
      tests/ut/ge/CMakeLists.txt
  11. +156
    -0
      tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc
  12. +75
    -0
      tests/ut/ge/graph/passes/no_data_out_const_elimination_pass_unittest.cc

+ 4
- 0
ge/CMakeLists.txt View File

@@ -157,7 +157,9 @@ set(TRAIN_SRC_LIST
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/fuse_data_nodes_with_common_input_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/no_data_out_const_elimination_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/control_trigger_pass.cc"
"graph/passes/dimension_adjust_pass.cc"
@@ -439,6 +441,7 @@ set(INFER_SRC_LIST
"graph/passes/net_output_pass.cc"
"graph/passes/replace_transshape_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/fuse_data_nodes_with_common_input_pass.cc"
"graph/passes/print_op_pass.cc"
"graph/passes/no_use_reshape_remove_pass.cc"
"graph/passes/iterator_op_pass.cc"
@@ -535,6 +538,7 @@ set(INFER_SRC_LIST
"graph/passes/addn_pass.cc"
"graph/passes/common_subexpression_elimination_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/no_data_out_const_elimination_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/transop_symmetry_elimination_pass.cc"
"graph/passes/save_pass.cc"


+ 2
- 0
ge/ge_inference.mk View File

@@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/net_output_pass.cc \
graph/passes/replace_transshape_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/fuse_data_nodes_with_common_input_pass.cc \
graph/passes/print_op_pass.cc \
graph/passes/no_use_reshape_remove_pass.cc \
graph/passes/iterator_op_pass.cc \
@@ -193,6 +194,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/cond_pass.cc \
graph/passes/cond_remove_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/no_data_out_const_elimination_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/for_pass.cc \
graph/passes/enter_pass.cc \


+ 2
- 0
ge/ge_runner.mk View File

@@ -127,7 +127,9 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/compile_nodes_pass.cc \
graph/passes/constant_folding_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/fuse_data_nodes_with_common_input_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/no_data_out_const_elimination_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/control_trigger_pass.cc \
graph/passes/dimension_adjust_pass.cc \


+ 1
- 1
ge/graph/load/new_model_manager/zero_copy_offset.h View File

@@ -65,7 +65,7 @@ class ZeroCopyOffset {
// data_size of Data/Netoutput
int64_t GetDataSize() const { return data_size_; }
// value of *outside_addrs_ from davinci_model
std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; }
const std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; }
// name of op
std::string GetOpName() const { return op_name_; }



+ 27
- 5
ge/graph/manager/graph_manager.cc View File

@@ -53,6 +53,7 @@
#include "graph/passes/dimension_adjust_pass.h"
#include "graph/passes/dimension_compute_pass.h"
#include "graph/passes/flow_ctrl_pass.h"
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h"
#include "graph/passes/identity_pass.h"
#include "graph/passes/input_output_connection_identify_pass.h"
#include "graph/passes/iterator_op_pass.h"
@@ -70,6 +71,7 @@
#include "graph/passes/remove_same_const_pass.h"
#include "graph/passes/reshape_recovery_pass.h"
#include "graph/passes/reshape_remove_pass.h"
#include "graph/passes/no_data_out_const_elimination_pass.h"
#include "graph/passes/same_transdata_breadth_fusion_pass.h"
#include "graph/passes/subgraph_pass.h"
#include "graph/passes/switch_data_edges_bypass.h"
@@ -2104,6 +2106,24 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass));
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass));
/*
* Do CSE before FuseDataNodesWithCommonInputPass to resolve the scene in bertlarge as following:
* const
* / | \
* cast1 cast2 cast3
* \ | /
* case
* the node `const` is the fused const node after ConstantFuseSamePass
* the nodes `cast1`, `cast2` and 'cast3' will be fused by CSE.
* in order to eliminate hard code in FuseDataNodesWithCommonInputPass,
* we do CSE before FuseDataNodesWithCommonInputPass
* But it is a temp solution, this CSE will be deleted after change pass from graph pass to node pass
*/
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CSEBeforeFuseDataNodesWithCommonInputPass",
new (std::nothrow) CommonSubexpressionEliminationPass));
// FuseDataNodesWithCommonInputPass: fuse same data with common input in same graph
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::FuseDataNodesWithCommonInputPass",
new (std::nothrow) FuseDataNodesWithCommonInputPass));
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass",
new (std::nothrow) CommonSubexpressionEliminationPass));
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass))
@@ -2226,12 +2246,14 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret);
return ret;
}
NamesToPass identity_remove_pass;
GE_TIMESTAMP_START(identity_remove_pass);
NamesToPass node_pass;
GE_TIMESTAMP_START(node_pass);
IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass
identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass);
ret = GEPass(compute_graph).Run(identity_remove_pass);
GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass");
NoDataOutConstEliminationPass no_data_out_const_elimination_pass;
node_pass.emplace_back("IdentityPass", &identity_force_pass);
node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass);
ret = GEPass(compute_graph).Run(node_pass);
GE_TIMESTAMP_END(node_pass, "GraphPrepare::node_pass");
if (ret != SUCCESS) {
GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret);
return ret;


+ 119
- 0
ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc View File

@@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/fuse_data_nodes_with_common_input_pass.h"

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <set>
#include "common/ge_inner_error_codes.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"

using std::map;
using std::vector;
using std::set;
using std::string;

namespace ge {
Status FuseDataNodesWithCommonInputPass::Run(ge::ComputeGraphPtr graph) {
if (graph == nullptr) {
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null.");
return GE_GRAPH_PARAM_NULLPTR;
}
GELOGD("FuseDataNodesWithCommonInputPass in.");
// key: subgraph, value:--key: peer out anchor to parent node, --value: parent indexes to parent node
map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> subgraphs_to_need_fuse_nodes_info;
if (InitNeedFuseNodesInfo(graph, subgraphs_to_need_fuse_nodes_info) != SUCCESS) {
GELOGE(FAILED, "InitNeedFuseNodesInfo failed.");
return FAILED;
}
return FuseDataNodes(subgraphs_to_need_fuse_nodes_info);
}

Status FuseDataNodesWithCommonInputPass::InitNeedFuseNodesInfo(ComputeGraphPtr &graph,
map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) {
for (const auto &subgraph : graph->GetAllSubgraphs()) {
GE_CHECK_NOTNULL(subgraph);
auto parent_node = subgraph->GetParentNode();
GE_CHECK_NOTNULL(parent_node);
if (parent_node->GetType() == CASE || parent_node->GetType() == IF) {
auto &peer_out_anchors_to_parent_indexes = subgraphs_to_need_fuse_nodes_info[subgraph];
for (const auto &in_data_anchor : parent_node->GetAllInDataAnchors()) {
GE_CHECK_NOTNULL(in_data_anchor);
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
uint32_t parent_index = static_cast<uint32_t>(in_data_anchor->GetIdx());
GE_CHECK_NOTNULL(peer_out_anchor);
peer_out_anchors_to_parent_indexes[peer_out_anchor].insert(parent_index);
GELOGD("Peer node %s is the %d input of parent node %s in %s.",
peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_index, parent_node->GetName().c_str(),
subgraph->GetName().c_str());
}
}
}
return SUCCESS;
}

Status FuseDataNodesWithCommonInputPass::FuseDataNodes(
const map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) {
for (const auto &subgraph_to_need_fuse_nodes_info : subgraphs_to_need_fuse_nodes_info) {
auto subgraph = subgraph_to_need_fuse_nodes_info.first;
for (const auto &peer_out_anchors_to_parent_indexes : subgraph_to_need_fuse_nodes_info.second) {
if (peer_out_anchors_to_parent_indexes.second.size() <= 1) {
continue;
}
// key: out anchor, value: data nodes with common input will be fused
map<OutDataAnchorPtr, vector<NodePtr>> peer_out_anchors_to_need_fuse_nodes;
for (const auto &node : subgraph->GetDirectNode()) {
if (node->GetType() != DATA) {
continue;
}
GE_CHECK_NOTNULL(node->GetOpDesc());
uint32_t parent_index = 0;
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
if (peer_out_anchors_to_parent_indexes.second.count(parent_index) > 0) {
peer_out_anchors_to_need_fuse_nodes[peer_out_anchors_to_parent_indexes.first].emplace_back(node);
}
}
}
for (const auto &peer_out_anchor_to_need_fuse_nodes : peer_out_anchors_to_need_fuse_nodes) {
auto need_fuse_data_nodes = peer_out_anchor_to_need_fuse_nodes.second;
auto first_node = need_fuse_data_nodes.at(0);
for (size_t i = 1; i < need_fuse_data_nodes.size(); ++i) {
auto node = need_fuse_data_nodes.at(i);
GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(),
first_node->GetName().c_str(), subgraph->GetName().c_str());
// the data node which can be fused has none input(both data and control in)
if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) {
return FAILED;
}
if (GraphUtils::ReplaceNodeDataAnchors(first_node, node, {}, {0}) != SUCCESS) {
return FAILED;
}
if (GraphUtils::RemoveNodeWithoutRelink(subgraph, node) != SUCCESS) {
GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str());
return FAILED;
}
}
}
}
}
return SUCCESS;
}
} // namespace ge

+ 38
- 0
ge/graph/passes/fuse_data_nodes_with_common_input_pass.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_
#define GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_

#include <set>
#include <map>
#include <vector>
#include "graph/types.h"
#include "inc/graph_pass.h"

namespace ge {
class FuseDataNodesWithCommonInputPass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override;

private:
Status InitNeedFuseNodesInfo(ComputeGraphPtr &graph,
map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info);
Status FuseDataNodes(
const map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_

+ 36
- 0
ge/graph/passes/no_data_out_const_elimination_pass.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/no_data_out_const_elimination_pass.h"

namespace ge {
Status NoDataOutConstEliminationPass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GELOGD("RemoveConstWithoutDataPass running of %s.", node->GetName().c_str());
if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
GE_CHECK_NOTNULL(node->GetOpDesc());
// delete const which has no input and no output of data
if (node->GetOpDesc()->GetInputsSize() == 0 && node->GetOutDataNodes().size() == 0) {
GELOGI("Remove const %s.", node->GetName().c_str());
if (IsolateAndDeleteNode(node, {}) != SUCCESS) {
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str());
return FAILED;
}
}
}
return SUCCESS;
}
} // namespace ge

+ 31
- 0
ge/graph/passes/no_data_out_const_elimination_pass.h View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_
#define GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_

#include "graph/passes/base_pass.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"

namespace ge {
class NoDataOutConstEliminationPass : public BaseNodePass {
public:
Status Run(ge::NodePtr &node) override;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_

+ 4
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -178,6 +178,8 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/net_output_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/replace_transshape_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/constant_fuse_same_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/fuse_data_nodes_with_common_input_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/no_data_out_const_elimination_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/print_op_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/no_use_reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/iterator_op_pass.cc"
@@ -616,6 +618,8 @@ set(PASS_TEST_FILES
"graph/passes/trans_op_depth_fusion_pass_unittest.cc"
"graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc"
"graph/passes/constant_folding_pass_unittest.cc"
"graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc"
"graph/passes/no_data_out_const_elimination_pass_unittest.cc"
"graph/passes/stop_gradient_pass_unittest.cc"
"graph/passes/prevent_gradient_pass_unittest.cc"
"graph/passes/identity_pass_unittest.cc"


+ 156
- 0
tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc View File

@@ -0,0 +1,156 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/fuse_data_nodes_with_common_input_pass.h"

#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <map>

#include "inc/pass_manager.h"
#include "common/ge_inner_error_codes.h"
#include "graph_builder_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"

namespace ge {

class UtestFuseDataNodesWithCommonInputPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}

public:
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
auto op_desc = std::make_shared<OpDesc>(name, type);
for (auto i = 0; i < in_num; ++i) {
op_desc->AddInputDesc(test_desc);
}
for (auto i = 0; i < out_num; ++i) {
op_desc->AddOutputDesc(test_desc);
}
return graph->AddNode(op_desc);
}
};

/// graph with subgraph
/// const
/// | | |
/// case
/// |
/// netoutput
/// ...
/// data0 data1 data2
/// | \ /
/// conv add
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
PassManager pass_manager;
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass);
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
auto parent_case = MakeNode(parent_graph, 3, 1, "parent_case", "Case");
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");

GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT);

parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(2, tensor_desc);
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);

GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2));
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));

ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data");
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
(void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);

sub_graph->SetParentNode(parent_case);
sub_graph->SetParentGraph(parent_graph);
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
}

/// graph with subgraph
/// const
/// / \
/// cast1 cast2
/// \ /
/// case
/// |
/// netoutput
/// ...
/// data1 data2
/// \ /
/// add
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
PassManager pass_manager;
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass);
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast");
auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast");
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case");
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");

GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT);

parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);

GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0));
GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));

ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);

sub_graph->SetParentNode(parent_case);
sub_graph->SetParentGraph(parent_graph);
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
}
} // namespace ge

+ 75
- 0
tests/ut/ge/graph/passes/no_data_out_const_elimination_pass_unittest.cc View File

@@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/no_data_out_const_elimination_pass.h"

#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <map>

#include "common/ge_inner_error_codes.h"
#include "graph/utils/graph_utils.h"

namespace ge {

class UtestNoDataOutConstEliminationPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}

public:
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
auto op_desc = std::make_shared<OpDesc>(name, type);
for (auto i = 0; i < in_num; ++i) {
op_desc->AddInputDesc(test_desc);
}
for (auto i = 0; i < out_num; ++i) {
op_desc->AddOutputDesc(test_desc);
}
return graph->AddNode(op_desc);
}
};

/// graph with subgraph
/// const1
/// |(control)
/// const2
/// |
/// output
TEST_F(UtestNoDataOutConstEliminationPass, succ_graph1) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
auto const_node1 = MakeNode(graph, 0, 1, "const_node1", "Const");
auto const_node2 = MakeNode(graph, 1, 1, "const_node2", "Const");
auto output_node = MakeNode(graph, 1, 0, "output_node", "NetOutput");
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT);

const_node1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
const_node2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
const_node2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
output_node->GetOpDesc()->UpdateInputDesc(0, tensor_desc);

GraphUtils::AddEdge(const_node1->GetOutControlAnchor(), const_node2->GetInControlAnchor());
GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));

GEPass pass(graph);
NamesToPass node_pass;
NoDataOutConstEliminationPass no_data_out_const_elimination_pass;
node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass);
EXPECT_EQ(pass.Run(node_pass), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save