Browse Source

skip control flow op when replace node with empty tensor

tags/v1.5.1
wangzhengjun 3 years ago
parent
commit
64b22f8c98
2 changed files with 65 additions and 0 deletions
  1. +20
    -0
      ge/graph/passes/replace_with_empty_const_pass.cc
  2. +45
    -0
      tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc

+ 20
- 0
ge/graph/passes/replace_with_empty_const_pass.cc View File

@@ -21,7 +21,23 @@
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"

namespace {
const std::unordered_set<std::string> kControlFlowOps = {
ge::SWITCH,
ge::REFSWITCH,
ge::MERGE,
ge::REFMERGE,
ge::ENTER,
ge::REFENTER,
ge::NEXTITERATION,
ge::REFNEXTITERATION,
ge::EXIT,
ge::REFEXIT,
ge::LOOPCOND
};
}
namespace ge {
Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGD("ReplaceWithEmptyConstPass in.");
@@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str());
return SUCCESS;
}
if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) {
GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str());
return SUCCESS;
}
// Node like no op, it has no output
if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) {
GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str());


+ 45
- 0
tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc View File

@@ -57,6 +57,36 @@ ut::GraphBuilder Graph1Builder() {
builder.AddDataEdge(cast1, 0, conv2d, 0);
return builder;
}

/// data1 const1
/// \ /
/// add1
/// |
/// data2 -> switch1 (empty)
/// |
/// conv2d
ut::GraphBuilder Graph2Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("graph2");
auto data1 = builder.AddNode("data1", "Data", 0, 1);
auto data2 = builder.AddNode("data2", "Data", 0, 1);
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto add1 = builder.AddNode("add1", "Add", 2, 1);
auto switch1 = builder.AddNode("switch1", "Switch", 2, 1);
auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0);

add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW));
add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW));
add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1, 1, 8, 8}),FORMAT_NCHW));
GeTensorDesc empty_tensor(GeShape({1, 0, 8, 8}),FORMAT_NCHW);
switch1->GetOpDesc()->UpdateOutputDesc(0, empty_tensor);

builder.AddDataEdge(data1, 0, add1, 0);
builder.AddDataEdge(const1, 0, add1, 1);
builder.AddDataEdge(add1, 0, switch1, 0);
builder.AddDataEdge(data2, 0, switch1, 1);
builder.AddDataEdge(switch1, 0, conv2d, 0);
return builder;
}
} // namespace


@@ -85,4 +115,19 @@ TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) {
auto conv2d = graph->FindNode("conv2d");
EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const");
}

TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_switch_skip) {
auto builder = Graph2Builder();
auto graph = builder.GetGraph();
graph->SetSessionID(0);
ReplaceWithEmptyConstPass replace_with_empty_const_pass;

EXPECT_EQ(graph->GetDirectNodesSize(), 6);
// run pass on switch1, graph still has 6 nodes
auto switch1 = graph->FindNode("switch1");
EXPECT_NE(switch1, nullptr);
Status ret = replace_with_empty_const_pass.Run(switch1);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph->GetDirectNodesSize(), 6);
}
} // namespace ge

Loading…
Cancel
Save