/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/print_op_pass.h"

#include <memory>

#include <gtest/gtest.h>

#include "omg/omg_inner_types.h"
#include "utils/op_desc_utils.h"

using domi::GetContext;

namespace ge {
/// Test fixture for PrintOpPass.
/// Provides make_graph(), which builds a small graph containing one "Print" node.
class UtestGraphPassesPrintOpPass : public testing::Test {
 protected:
  void SetUp() override {}

  void TearDown() override {}

 public:
  /// Populates @p graph with the topology:
  ///
  ///   Data ──► Print(input 0)
  ///   Data ──► Print(input 1)
  ///   Print(output 0) ──► RetVal(input 0)
  ///
  /// Every input/output tensor desc uses shape {1, 1, 1, 1}.
  /// Edge-creation failures are reported as non-fatal test failures.
  ///
  /// @param graph  Graph that receives the nodes and edges.
  /// @param match  Currently unused. Kept so existing call sites still compile.
  /// @param flag   Currently unused. Kept so existing call sites still compile.
  void make_graph(ComputeGraphPtr graph, bool /* match */ = true, int /* flag */ = 0) {
    // First input source.
    auto data = std::make_shared<OpDesc>("Data", DATA);
    GeTensorDesc tensor_desc_data(GeShape({1, 1, 1, 1}));
    data->AddInputDesc(tensor_desc_data);
    data->AddOutputDesc(tensor_desc_data);
    auto data_node = graph->AddNode(data);

    // Second input source (reuses the same op name "Data" as the first).
    auto data1 = std::make_shared<OpDesc>("Data", DATA);
    data1->AddInputDesc(tensor_desc_data);
    data1->AddOutputDesc(tensor_desc_data);
    auto data_node1 = graph->AddNode(data1);

    // The Print node under test: two data inputs, one output.
    auto print_desc = std::make_shared<OpDesc>("Print", "Print");
    print_desc->AddInputDesc(tensor_desc_data);
    print_desc->AddInputDesc(tensor_desc_data);
    print_desc->AddOutputDesc(tensor_desc_data);
    auto print_node = graph->AddNode(print_desc);

    // Consumer of the Print output.
    auto ret_val_desc = std::make_shared<OpDesc>("RetVal", "RetVal");
    ret_val_desc->AddInputDesc(tensor_desc_data);
    ret_val_desc->AddOutputDesc(tensor_desc_data);
    auto ret_val_node = graph->AddNode(ret_val_desc);

    // Check every edge. Previously the results were assigned to `ret` and
    // discarded, so a wiring failure would silently leave a malformed graph
    // for the pass under test.
    EXPECT_EQ(GRAPH_SUCCESS,
              GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), print_node->GetInDataAnchor(0)));
    EXPECT_EQ(GRAPH_SUCCESS,
              GraphUtils::AddEdge(data_node1->GetOutDataAnchor(0), print_node->GetInDataAnchor(1)));
    EXPECT_EQ(GRAPH_SUCCESS,
              GraphUtils::AddEdge(print_node->GetOutDataAnchor(0), ret_val_node->GetInDataAnchor(0)));
  }
};

/// Runs PrintOpPass through GEPass over the graph from make_graph.
/// Expects the run to report SUCCESS.
TEST_F(UtestGraphPassesPrintOpPass, apply_success) {
  GetContext().out_nodes_map.clear();
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
  make_graph(graph);

  ge::PrintOpPass apply_pass;
  NamesToPass names_to_pass;
  names_to_pass.emplace_back("Test", &apply_pass);

  GEPass pass(graph);
  Status status = pass.Run(names_to_pass);
  EXPECT_EQ(SUCCESS, status);
}

/// Calling PrintOpPass::Run directly with a null node must return PARAM_INVALID.
TEST_F(UtestGraphPassesPrintOpPass, param_invalid) {
  ge::NodePtr node = nullptr;
  ge::PrintOpPass apply_pass;
  Status status = apply_pass.Run(node);
  EXPECT_EQ(ge::PARAM_INVALID, status);
}
}  // namespace ge