modified: ge/graph/passes/dimension_adjust_pass.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cctags/v1.3.0
| @@ -78,7 +78,12 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||||
| GELOGE(ret, "DimensionAdjustPass compute failed"); | GELOGE(ret, "DimensionAdjustPass compute failed"); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| // Need to handle axis_input of node like ExpandDims | |||||
| if (node->GetAllInDataAnchors().size() > static_cast<size_t>(kRemoveInputIndex)) { | if (node->GetAllInDataAnchors().size() > static_cast<size_t>(kRemoveInputIndex)) { | ||||
| auto axis_node_out_anchor = node->GetInDataAnchor(kRemoveInputIndex)->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(axis_node_out_anchor); | |||||
| auto axis_node = axis_node_out_anchor->GetOwnerNode(); | |||||
| // 1.Copy control dependency of axis node | |||||
| ret = PassUtils::UnlinkNodeWithControlCopy(node, kRemoveInputIndex); | ret = PassUtils::UnlinkNodeWithControlCopy(node, kRemoveInputIndex); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Unlink op:%s(%s) data input:%u with control edge copy failed", | REPORT_CALL_ERROR("E19999", "Unlink op:%s(%s) data input:%u with control edge copy failed", | ||||
| @@ -86,6 +91,13 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||||
| GELOGE(ret, "DimensionAdjustPass unlink node with control copy fail."); | GELOGE(ret, "DimensionAdjustPass unlink node with control copy fail."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| // 2.Remove const axis node without any output | |||||
| if ((axis_node->GetType() == CONSTANT || axis_node->GetType() == CONSTANTOP) && | |||||
| axis_node->GetOutDataNodesSize() == 0) { | |||||
| ret = IsolateAndDeleteNode(axis_node, {}); | |||||
| GE_CHK_GRAPH_STATUS_RET(ret, "Fail to remove node %s.", axis_node->GetName().c_str()); | |||||
| GELOGI("Remove useless axis input const %s", axis_node->GetName().c_str()); | |||||
| } | |||||
| } | } | ||||
| ret = DealWithInNodes(node); | ret = DealWithInNodes(node); | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "inc/kernel.h" | |||||
| #include "inc/kernel_factory.h" | #include "inc/kernel_factory.h" | ||||
| #undef protected | #undef protected | ||||
| #undef private | #undef private | ||||
| @@ -37,11 +38,27 @@ using namespace testing; | |||||
| namespace ge { | namespace ge { | ||||
| class TestExpandDimKernel : public Kernel { | |||||
| public: | |||||
| Status Compute(const NodePtr &node_ptr) override { | |||||
| return SUCCESS; | |||||
| } | |||||
| }; | |||||
| REGISTER_KERNEL(EXPANDDIMS, TestExpandDimKernel); | |||||
| class TestExpandDimKernelNotChange : public Kernel { | |||||
| public: | |||||
| Status Compute(const NodePtr &node_ptr) override { | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| }; | |||||
| class UtestGraphPassesDimensionAdjustPass : public testing::Test { | class UtestGraphPassesDimensionAdjustPass : public testing::Test { | ||||
| protected: | protected: | ||||
| void SetUp() {} | void SetUp() {} | ||||
| void TearDown() {} | |||||
| void TearDown() { | |||||
| KernelFactory::Instance().creator_map_.clear(); | |||||
| } | |||||
| }; | }; | ||||
| TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { | TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { | ||||
| @@ -96,8 +113,11 @@ TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { | |||||
| GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); | ||||
| std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>(); | std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>(); | ||||
| NamesToPass names_to_passes; | |||||
| EXPECT_EQ(4, graph->GetDirectNodesSize()); | |||||
| ge::Status ret = pass->Run(op_node); | ge::Status ret = pass->Run(op_node); | ||||
| EXPECT_EQ(SUCCESS, ret); | EXPECT_EQ(SUCCESS, ret); | ||||
| EXPECT_EQ(2, op_node->GetOwnerComputeGraph()->GetDirectNodesSize()); | |||||
| } | } | ||||
| TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { | TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { | ||||