modified: ge/graph/passes/dimension_adjust_pass.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc modified: ge/graph/passes/dimension_adjust_pass.cc modified: tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cctags/v1.3.0
@@ -78,7 +78,12 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||||
GELOGE(ret, "DimensionAdjustPass compute failed"); | GELOGE(ret, "DimensionAdjustPass compute failed"); | ||||
return ret; | return ret; | ||||
} | } | ||||
// Need to handle axis_input of node like ExpandDims | |||||
if (node->GetAllInDataAnchors().size() > static_cast<size_t>(kRemoveInputIndex)) { | if (node->GetAllInDataAnchors().size() > static_cast<size_t>(kRemoveInputIndex)) { | ||||
auto axis_node_out_anchor = node->GetInDataAnchor(kRemoveInputIndex)->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(axis_node_out_anchor); | |||||
auto axis_node = axis_node_out_anchor->GetOwnerNode(); | |||||
// 1.Copy control dependency of axis node | |||||
ret = PassUtils::UnlinkNodeWithControlCopy(node, kRemoveInputIndex); | ret = PassUtils::UnlinkNodeWithControlCopy(node, kRemoveInputIndex); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "Unlink op:%s(%s) data input:%u with control edge copy failed", | REPORT_CALL_ERROR("E19999", "Unlink op:%s(%s) data input:%u with control edge copy failed", | ||||
@@ -86,6 +91,13 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { | |||||
GELOGE(ret, "DimensionAdjustPass unlink node with control copy fail."); | GELOGE(ret, "DimensionAdjustPass unlink node with control copy fail."); | ||||
return ret; | return ret; | ||||
} | } | ||||
// 2.Remove const axis node without any output | |||||
if ((axis_node->GetType() == CONSTANT || axis_node->GetType() == CONSTANTOP) && | |||||
axis_node->GetOutDataNodesSize() == 0) { | |||||
ret = IsolateAndDeleteNode(axis_node, {}); | |||||
GE_CHK_GRAPH_STATUS_RET(ret, "Fail to remove node %s.", axis_node->GetName().c_str()); | |||||
GELOGI("Remove useless axis input const %s", axis_node->GetName().c_str()); | |||||
} | |||||
} | } | ||||
ret = DealWithInNodes(node); | ret = DealWithInNodes(node); | ||||
@@ -28,6 +28,7 @@ | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "inc/kernel.h" | |||||
#include "inc/kernel_factory.h" | #include "inc/kernel_factory.h" | ||||
#undef protected | #undef protected | ||||
#undef private | #undef private | ||||
@@ -37,11 +38,27 @@ using namespace testing; | |||||
namespace ge { | namespace ge { | ||||
class TestExpandDimKernel : public Kernel { | |||||
public: | |||||
Status Compute(const NodePtr &node_ptr) override { | |||||
return SUCCESS; | |||||
} | |||||
}; | |||||
REGISTER_KERNEL(EXPANDDIMS, TestExpandDimKernel); | |||||
class TestExpandDimKernelNotChange : public Kernel { | |||||
public: | |||||
Status Compute(const NodePtr &node_ptr) override { | |||||
return NOT_CHANGED; | |||||
} | |||||
}; | |||||
class UtestGraphPassesDimensionAdjustPass : public testing::Test { | class UtestGraphPassesDimensionAdjustPass : public testing::Test { | ||||
protected: | protected: | ||||
void SetUp() {} | void SetUp() {} | ||||
void TearDown() {} | |||||
void TearDown() { | |||||
KernelFactory::Instance().creator_map_.clear(); | |||||
} | |||||
}; | }; | ||||
TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { | TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { | ||||
@@ -96,8 +113,11 @@ TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { | |||||
GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); | ||||
std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>(); | std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>(); | ||||
NamesToPass names_to_passes; | |||||
EXPECT_EQ(4, graph->GetDirectNodesSize()); | |||||
ge::Status ret = pass->Run(op_node); | ge::Status ret = pass->Run(op_node); | ||||
EXPECT_EQ(SUCCESS, ret); | EXPECT_EQ(SUCCESS, ret); | ||||
EXPECT_EQ(2, op_node->GetOwnerComputeGraph()->GetDirectNodesSize()); | |||||
} | } | ||||
TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { | TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { | ||||