|
@@ -28,6 +28,7 @@ |
|
|
#include "graph/types.h" |
|
|
#include "graph/types.h" |
|
|
#include "graph/utils/graph_utils.h" |
|
|
#include "graph/utils/graph_utils.h" |
|
|
#include "graph/utils/op_desc_utils.h" |
|
|
#include "graph/utils/op_desc_utils.h" |
|
|
|
|
|
#include "inc/kernel.h" |
|
|
#include "inc/kernel_factory.h" |
|
|
#include "inc/kernel_factory.h" |
|
|
#undef protected |
|
|
#undef protected |
|
|
#undef private |
|
|
#undef private |
|
@@ -37,11 +38,27 @@ using namespace testing; |
|
|
|
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
|
|
|
|
|
|
|
|
|
|
class TestExpandDimKernel : public Kernel { |
|
|
|
|
|
public: |
|
|
|
|
|
Status Compute(const NodePtr &node_ptr) override { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
REGISTER_KERNEL(EXPANDDIMS, TestExpandDimKernel); |
|
|
|
|
|
class TestExpandDimKernelNotChange : public Kernel { |
|
|
|
|
|
public: |
|
|
|
|
|
Status Compute(const NodePtr &node_ptr) override { |
|
|
|
|
|
return NOT_CHANGED; |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
class UtestGraphPassesDimensionAdjustPass : public testing::Test { |
|
|
class UtestGraphPassesDimensionAdjustPass : public testing::Test { |
|
|
protected: |
|
|
protected: |
|
|
void SetUp() {} |
|
|
void SetUp() {} |
|
|
|
|
|
|
|
|
void TearDown() {} |
|
|
|
|
|
|
|
|
void TearDown() { |
|
|
|
|
|
KernelFactory::Instance().creator_map_.clear(); |
|
|
|
|
|
} |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { |
|
|
TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { |
|
@@ -96,8 +113,11 @@ TEST_F(UtestGraphPassesDimensionAdjustPass, succ) { |
|
|
GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); |
|
|
GraphUtils::AddEdge(op_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0)); |
|
|
|
|
|
|
|
|
std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>(); |
|
|
std::shared_ptr<DimensionAdjustPass> pass = make_shared<DimensionAdjustPass>(); |
|
|
|
|
|
NamesToPass names_to_passes; |
|
|
|
|
|
EXPECT_EQ(4, graph->GetDirectNodesSize()); |
|
|
ge::Status ret = pass->Run(op_node); |
|
|
ge::Status ret = pass->Run(op_node); |
|
|
EXPECT_EQ(SUCCESS, ret); |
|
|
EXPECT_EQ(SUCCESS, ret); |
|
|
|
|
|
EXPECT_EQ(2, op_node->GetOwnerComputeGraph()->GetDirectNodesSize()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { |
|
|
TEST_F(UtestGraphPassesDimensionAdjustPass, input_node_is_nullptr) { |
|
|