/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #define protected public #define private public #include "graph/passes/infershape_pass.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" #include "graph/operator_factory.h" #include "graph/operator_reg.h" #include "graph_builder_utils.h" using namespace std; using namespace testing; namespace ge { class UtestGraphInfershapePass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { OpDescPtr op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); static int32_t index = 0; op_desc->SetId(index++); GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); TensorUtils::SetSize(tensor, 512); vector input_offset; for (int i = 0; i < in_num; i++) { op_desc->AddInputDesc(tensor); input_offset.emplace_back(1024); } op_desc->SetInputOffset(input_offset); vector output_offset; for (int i = 0; i < out_num; i++) { op_desc->AddOutputDesc(tensor); output_offset.emplace_back(1024); } op_desc->SetOutputOffset(output_offset); op_desc->SetWorkspace({}); op_desc->SetWorkspaceBytes({}); op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; op_desc->AddInferFunc(stub_func); op_desc->AddInferFormatFunc(stub_func); op_desc->AddVerifierFunc(stub_func); return graph.AddNode(op_desc); } TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); string type = "AddN"; auto addn_op_desc = std::make_shared("AddN", type); addn_op_desc->AddInputDesc(ge_tensor_desc); addn_op_desc->AddOutputDesc(ge_tensor_desc); auto graph = std::make_shared("test"); auto addn_node = std::make_shared(addn_op_desc, graph); addn_node->Init(); InferShapePass infershape_pass; EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); } TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { auto graph = std::make_shared("test"); auto no_op_desc = std::make_shared("No", "NoOp"); auto no_op_node = graph->AddNode(no_op_desc); AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); InferShapePass infershape_pass; infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); } TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { /******************************************************************************* * Exit Identify * \ / \. * \ / \. * Switch Add * / | | * / | | * / | | * LoopCond | | * \ | | * \ | | * \ | | * Less | | * \ | NextIteration * \ | | * \ | | * Merge <---------| * | * | * Enter ******************************************************************************/ auto graph = std::make_shared("test_infer_shape"); auto data1 = CreateNode(*graph, "data", DATA, 1, 1); auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); auto less1 = CreateNode(*graph, "less", LESS, 2, 1); auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); auto add1 = CreateNode(*graph, "add", ADD, 2, 1); auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); GEPass ge_passes(graph); NamesToPass names_to_passes; InferShapePass infer_shape_pass; names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); } } // namespace ge