|
@@ -26,12 +26,9 @@ |
|
|
#include "graph/operator_factory.h" |
|
|
#include "graph/operator_factory.h" |
|
|
#include "graph/operator_reg.h" |
|
|
#include "graph/operator_reg.h" |
|
|
#include "graph_builder_utils.h" |
|
|
#include "graph_builder_utils.h" |
|
|
#undef protected |
|
|
|
|
|
#undef private |
|
|
|
|
|
|
|
|
|
|
|
using namespace std; |
|
|
using namespace std; |
|
|
using namespace testing; |
|
|
using namespace testing; |
|
|
using namespace ge; |
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
|
class UtestGraphInfershapePass : public testing::Test { |
|
|
class UtestGraphInfershapePass : public testing::Test { |
|
|
protected: |
|
|
protected: |
|
@@ -52,4 +49,17 @@ TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { |
|
|
InferShapePass infershape_pass; |
|
|
InferShapePass infershape_pass; |
|
|
EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); |
|
|
EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { |
|
|
|
|
|
auto graph = std::make_shared<ComputeGraph>("test"); |
|
|
|
|
|
|
|
|
|
|
|
auto no_op_desc = std::make_shared<OpDesc>("No", "NoOp"); |
|
|
|
|
|
auto no_op_node = graph->AddNode(no_op_desc); |
|
|
|
|
|
AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); |
|
|
|
|
|
|
|
|
|
|
|
InferShapePass infershape_pass; |
|
|
|
|
|
infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; |
|
|
|
|
|
EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
} // namespace ge |
|
|
} // namespace ge |