/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #define protected public #define private public #include "graph/passes/infershape_pass.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" #include "graph/operator_factory.h" #include "graph/operator_reg.h" #include "graph_builder_utils.h" using namespace std; using namespace testing; namespace ge { namespace { // do nothing stub infer_func const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; // infer from input to output stub infer_func (input size == output size) const auto stub_mapping_func = [](Operator &op) { size_t in_num = op.GetInputsSize(); for (size_t i = 0; i < in_num; ++i) { auto in_desc = op.GetInputDesc(i); auto out_desc = op.GetOutputDesc(i); out_desc.SetShape(in_desc.GetShape()); out_desc.SetDataType(in_desc.GetDataType()); op.UpdateOutputDesc(out_desc.GetName(), out_desc); } return GRAPH_SUCCESS; }; // merge infer_func // while infer_func const auto while_infer_func = [](Operator &op) { size_t in_num = op.GetInputsSize(); size_t out_num = op.GetOutputsSize(); if (in_num != out_num) { return GRAPH_FAILED; } bool need_infer_again = false; for (size_t i = 0; i < in_num; ++i) { auto in_desc = op.GetDynamicInputDesc("input", i); auto out_desc = op.GetDynamicOutputDesc("output", i); auto data_shape = in_desc.GetShape(); auto out_shape = out_desc.GetShape(); if(out_shape.GetDims() == DUMMY_SHAPE){ return GRAPH_SUCCESS; } // check datatype between output and input if (in_desc.GetDataType() != out_desc.GetDataType()) { return GRAPH_FAILED; } if (data_shape.GetDims() != out_shape.GetDims()) { need_infer_again = true; if (data_shape.GetDimNum() != out_shape.GetDimNum()) { in_desc.SetUnknownDimNumShape(); } else { size_t data_dim_num = data_shape.GetDimNum(); std::vector> data_shape_range = {data_dim_num, std::make_pair(1, UNKNOWN_DIM)}; for (size_t j = 0; j < data_dim_num; ++j) { if (data_shape.GetDim(j) != out_shape.GetDim(j)) { data_shape.SetDim(j, UNKNOWN_DIM); } if (data_shape.GetDim(j) != UNKNOWN_DIM) { data_shape_range[j] = std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)); } } in_desc.SetShape(data_shape); in_desc.SetShapeRange(data_shape_range); } op.UpdateDynamicOutputDesc("output", i, in_desc); op.UpdateDynamicInputDesc("input", i, in_desc); } } return need_infer_again ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; }; } class UtestGraphInfershapePass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num, std::function infer_func = stub_func) { OpDescPtr op_desc = std::make_shared(name, type); op_desc->SetStreamId(0); static int32_t index = 0; op_desc->SetId(index++); GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); TensorUtils::SetSize(tensor, 512); vector input_offset; for (int i = 0; i < in_num; i++) { op_desc->AddInputDesc(tensor); input_offset.emplace_back(1024); } op_desc->SetInputOffset(input_offset); vector output_offset; for (int i = 0; i < out_num; i++) { op_desc->AddOutputDesc(tensor); output_offset.emplace_back(1024); } op_desc->SetOutputOffset(output_offset); op_desc->SetWorkspace({}); op_desc->SetWorkspaceBytes({}); op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); op_desc->AddInferFunc(infer_func); return graph.AddNode(op_desc); } /* TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); string type = "AddN"; auto addn_op_desc = std::make_shared("AddN", type); addn_op_desc->AddInputDesc(ge_tensor_desc); addn_op_desc->AddOutputDesc(ge_tensor_desc); auto graph = std::make_shared("test"); auto addn_node = std::make_shared(addn_op_desc, graph); addn_node->Init(); InferShapePass infershape_pass; EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); } */ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { auto graph = std::make_shared("test"); auto no_op_desc = std::make_shared("No", "NoOp"); auto no_op_node = graph->AddNode(no_op_desc); AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); InferShapePass infershape_pass; infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); } TEST_F(UtestGraphInfershapePass, infer_from_pre_to_next) { /* * cast->shape */ auto graph = std::make_shared("test_infer_shape"); auto data1 = CreateNode(*graph, "dataq", DATA, 0, 1); auto cast1 = CreateNode(*graph, "cast1", CAST, 1, 1, stub_mapping_func); auto cast_in_desc = cast1->GetOpDesc()->MutableInputDesc(0); cast_in_desc->SetShape(GeShape({1,2,3})); cast_in_desc->SetDataType(DT_INT32); auto transdata1 = CreateNode(*graph, "transdata1", TRANSDATA, 1, 1, stub_mapping_func); GraphUtils::AddEdge(data1->GetOutDataAnchor(0), cast1->GetInDataAnchor(0)); GraphUtils::AddEdge(cast1->GetOutDataAnchor(0), transdata1->GetInDataAnchor(0)); // check before infer cast1 auto cast_before = graph->FindNode("cast1"); vector expect_cast1_shape_dim = {1,2,3}; auto real_cast1_before_shape_dim = cast_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(); auto transdata1_before = graph->FindNode("transdata1"); vector expect_transdata1_shape_dim = {}; auto real_transdata1_before_shape_dim = transdata1_before->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(); EXPECT_EQ(real_cast1_before_shape_dim, expect_cast1_shape_dim); EXPECT_EQ(real_transdata1_before_shape_dim, expect_transdata1_shape_dim); // run infershape pass InferShapePass infer_shape_pass; infer_shape_pass.Run(cast_before); // check cast1 add transdata1 to repass_immediately infer_shape_pass.GetNodesNeedRePassImmediately(); EXPECT_TRUE(!infer_shape_pass.GetNodesNeedRePassImmediately().empty()); // check transdata input_shape & datatype after infer auto transdata1_after = graph->FindNode("transdata1"); auto transdata1_opdesc = transdata1_before->GetOpDesc(); auto real_transdata1_after_shape_dim = transdata1_opdesc->GetInputDesc(0).GetShape().GetDims(); EXPECT_EQ(real_transdata1_after_shape_dim, expect_cast1_shape_dim); auto transdata1_datatype_after = transdata1_opdesc->GetInputDesc(0).GetDataType(); EXPECT_EQ(transdata1_datatype_after, DT_INT32); } TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { /******************************************************************************* * Exit Identify * \ / \. * \ / \. * Switch Add * / | | * / | | * / | | * LoopCond | | * \ | | * \ | | * \ | | * Less | | * \ | NextIteration * \ | | * \ | | * Merge <---------| * | * | * Enter ******************************************************************************/ auto graph = std::make_shared("test_infer_shape"); auto data1 = CreateNode(*graph, "data", DATA, 1, 1); auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); auto less1 = CreateNode(*graph, "less", LESS, 2, 1); auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); auto add1 = CreateNode(*graph, "add", ADD, 2, 1); auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); GEPass ge_passes(graph); NamesToPass names_to_passes; InferShapePass infer_shape_pass; names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); } } // namespace ge