Browse Source

!1157 Bugfix: replace_with_empty pass do not handle DATA nodes

From: @hugo1
Reviewed-by: @xchu42,@wqtshg
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
3a0a0c550e
3 changed files with 90 additions and 1 deletions
  1. +1
    -1
      ge/graph/passes/replace_with_empty_const_pass.cc
  2. +1
    -0
      tests/ut/ge/CMakeLists.txt
  3. +88
    -0
      tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc

+ 1
- 1
ge/graph/passes/replace_with_empty_const_pass.cc View File

@@ -33,7 +33,7 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
GELOGE(PARAM_INVALID, "Param [opDesc] must not be null."); GELOGE(PARAM_INVALID, "Param [opDesc] must not be null.");
return PARAM_INVALID; return PARAM_INVALID;
} }
if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP || node->GetType() == DATA) {
GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str()); GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str());
return SUCCESS; return SUCCESS;
} }


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -689,6 +689,7 @@ set(PASS_TEST_FILES
"graph/passes/no_use_reshape_remove_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc"
"graph/passes/infershape_pass_unittest.cc" "graph/passes/infershape_pass_unittest.cc"
"graph/passes/multi_batch_clone_pass_unittest.cc" "graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/replace_with_empty_const_pass_unittest.cc"
) )


set(KERNEL_TEST_FILES set(KERNEL_TEST_FILES


+ 88
- 0
tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc View File

@@ -0,0 +1,88 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/replace_with_empty_const_pass.h"

#include <gtest/gtest.h>
#include <set>
#include <string>

#include "graph_builder_utils.h"

namespace ge {
class UtestReplaceWithEmptyConstPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

namespace {
/// data1 const1
/// \ /
/// add1
/// |
/// cast1(empty)
/// |
/// conv2d
ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", "Data", 0, 1);
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto add1 = builder.AddNode("add1", "Add", 2, 1);
auto cast1 = builder.AddNode("cast1", "Cast", 1, 1);
auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0);

add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
GeTensorDesc empty_tensor(GeShape({1,0,8,8}),FORMAT_NCHW);
cast1->GetOpDesc()->UpdateOutputDesc(0,empty_tensor);

builder.AddDataEdge(data1, 0, add1, 0);
builder.AddDataEdge(const1, 0, add1, 1);
builder.AddDataEdge(add1, 0, cast1, 0);
builder.AddDataEdge(cast1, 0, conv2d, 0);
return builder;
}
} // namespace


TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) {
auto builder = Graph1Builder();
auto graph = builder.GetGraph();
graph->SetSessionID(0);
ReplaceWithEmptyConstPass replace_with_empty_const_pass;

EXPECT_EQ(graph->GetDirectNodesSize(),5);
// run pass on add1, graph still has 5 nodes
auto add1 = graph->FindNode("add1");
Status ret = replace_with_empty_const_pass.Run(add1);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph->GetDirectNodesSize(),5);

// run pass on const1, graph still has 5 nodes
auto const1 = graph->FindNode("const1");
ret = replace_with_empty_const_pass.Run(const1);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph->GetDirectNodesSize(),5);

auto cast1 = graph->FindNode("cast1");
ret = replace_with_empty_const_pass.Run(cast1);
EXPECT_EQ(cast1->GetOutAllNodes().size(),0);
auto conv2d = graph->FindNode("conv2d");
EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const");
}
} // namespace ge

Loading…
Cancel
Save