From 856bd59602acec737eb3fd74221d3c2137d26bc8 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 10 Apr 2021 14:42:35 +0800 Subject: [PATCH] modified: ge/graph/passes/dimension_adjust_pass.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc --- ge/graph/passes/dimension_adjust_pass.cc | 12 ++++++++++ .../passes/dimension_adjust_pass_unittest.cc | 22 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/ge/graph/passes/dimension_adjust_pass.cc b/ge/graph/passes/dimension_adjust_pass.cc index 61480f17..dbea8dc9 100755 --- a/ge/graph/passes/dimension_adjust_pass.cc +++ b/ge/graph/passes/dimension_adjust_pass.cc @@ -78,7 +78,12 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { GELOGE(ret, "DimensionAdjustPass compute failed"); return ret; } + // Need to handle axis_input of node like ExpandDims if (node->GetAllInDataAnchors().size() > static_cast(kRemoveInputIndex)) { + auto axis_node_out_anchor = node->GetInDataAnchor(kRemoveInputIndex)->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(axis_node_out_anchor); + auto axis_node = axis_node_out_anchor->GetOwnerNode(); + // 1.Copy control dependency of axis node ret = PassUtils::UnlinkNodeWithControlCopy(node, kRemoveInputIndex); if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "Unlink op:%s(%s) data input:%u with control edge copy failed", @@ -86,6 +91,13 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { GELOGE(ret, "DimensionAdjustPass unlink node with control copy fail."); return ret; } + // 2.Remove const axis node without any output + if ((axis_node->GetType() == CONSTANT || axis_node->GetType() == CONSTANTOP) && + axis_node->GetOutDataNodesSize() == 0) { + ret = IsolateAndDeleteNode(axis_node, {}); + GE_CHK_GRAPH_STATUS_RET(ret, "Fail to remove node %s.", axis_node->GetName().c_str()); + GELOGI("Remove useless axis input const %s", axis_node->GetName().c_str()); + } } ret = DealWithInNodes(node); diff --git a/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc b/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc index 79e34a60..41ea5828 100644 --- a/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc @@ -28,6 +28,7 @@ #include "graph/types.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" +#include "inc/kernel.h" #include "inc/kernel_factory.h" #undef protected #undef private @@ -37,11 +38,27 @@ using namespace testing; namespace ge { +class TestExpandDimKernel : public Kernel { + public: + Status Compute(const NodePtr &node_ptr) override { + return SUCCESS; + } +}; +REGISTER_KERNEL(EXPANDDIMS, TestExpandDimKernel); +class TestExpandDimKernelNotChange : public Kernel { + public: + Status Compute(const NodePtr &node_ptr) override { + return NOT_CHANGED; + } +}; + class UtestGraphPassesDimensionAdjustPass : public testing::Test { protected: void SetUp() {} - void TearDown() {} + void TearDown() { + KernelFactory::Instance().creator_map_.clear(); + } }; TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { @@ -96,8 +113,11 @@ TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); std::shared_ptr pass = make_shared(); + NamesToPass names_to_passes; + EXPECT_EQ(4, graph->GetDirectNodesSize()); ge::Status ret = pass->Run(op_node); EXPECT_EQ(SUCCESS, ret); + EXPECT_EQ(2, op_node->GetOwnerComputeGraph()->GetDirectNodesSize()); } TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) {