Browse Source

!1497 Bugfix: dimension adjust pass remove useless const,avoid isolate const partitioned to empty sub graph

From: @hugo1
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
4372a28069
2 changed files with 33 additions and 1 deletions
  1. +12
    -0
      ge/graph/passes/dimension_adjust_pass.cc
  2. +21
    -1
      tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc

+ 12
- 0
ge/graph/passes/dimension_adjust_pass.cc View File

@@ -78,7 +78,12 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) {
GELOGE(ret, "DimensionAdjustPass compute failed");
return ret;
}
// Need to handle axis_input of node like ExpandDims
if (node->GetAllInDataAnchors().size() > static_cast<size_t>(kRemoveInputIndex)) {
auto axis_node_out_anchor = node->GetInDataAnchor(kRemoveInputIndex)->GetPeerOutAnchor();
GE_CHECK_NOTNULL(axis_node_out_anchor);
auto axis_node = axis_node_out_anchor->GetOwnerNode();
// 1.Copy control dependency of axis node
ret = PassUtils::UnlinkNodeWithControlCopy(node, kRemoveInputIndex);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Unlink op:%s(%s) data input:%u with control edge copy failed",
@@ -86,6 +91,13 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) {
GELOGE(ret, "DimensionAdjustPass unlink node with control copy fail.");
return ret;
}
// 2.Remove const axis node without any output
if ((axis_node->GetType() == CONSTANT || axis_node->GetType() == CONSTANTOP) &&
axis_node->GetOutDataNodesSize() == 0) {
ret = IsolateAndDeleteNode(axis_node, {});
GE_CHK_GRAPH_STATUS_RET(ret, "Fail to remove node %s.", axis_node->GetName().c_str());
GELOGI("Remove useless axis input const %s", axis_node->GetName().c_str());
}
}

ret = DealWithInNodes(node);


+ 21
- 1
tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc View File

@@ -28,6 +28,7 @@
#include "graph/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"
#undef protected
#undef private
@@ -37,11 +38,27 @@ using namespace testing;

namespace ge {

class TestExpandDimKernel : public Kernel {
public:
Status Compute(const NodePtr &node_ptr) override {
return SUCCESS;
}
};
REGISTER_KERNEL(EXPANDDIMS, TestExpandDimKernel);
class TestExpandDimKernelNotChange : public Kernel {
public:
Status Compute(const NodePtr &node_ptr) override {
return NOT_CHANGED;
}
};

class UtestGraphPassesDimensionAdjustPass : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
void TearDown() {
KernelFactory::Instance().creator_map_.clear();
}
};

TEST_F(UtestGraphPassesDimensionAdjustPass, succ) {
@@ -96,8 +113,11 @@ TEST_F(UtestGraphPassesDimensionAdjustPass, succ) {
GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));

std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>();
NamesToPass names_to_passes;
EXPECT_EQ(4, graph->GetDirectNodesSize());
ge::Status ret = pass->Run(op_node);
EXPECT_EQ(SUCCESS, ret);
EXPECT_EQ(2, op_node->GetOwnerComputeGraph()->GetDirectNodesSize());
}

TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) {


Loading…
Cancel
Save