/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #define protected public #define private public #include "graph/passes/cast_remove_pass.h" #undef protected #undef private #include "anchor.h" #include "common/debug/log.h" #include "common/debug/memory_dumper.h" #include "common/op/attr_value_util.h" #include "common/types.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/attr_value.h" #include "graph/debug/ge_attr_define.h" #include "inc/pass_manager.h" #include "graph_builder_utils.h" #include #include #include #include "opskernel_manager/ops_kernel_manager.h" #include "omg/omg_inner_types.h" using namespace testing; using namespace ge; using namespace std; class UtestGraphPassesCastRemovePass : public testing::Test { protected: void SetUp() {} void TearDown() {} }; // case1:no net_out_put_node // TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) { // std::vector nodes_to_fuse; // auto builder = ut::GraphBuilder("g1"); // auto data = builder.AddNode("data", DATA, 1, 1); // auto cast1 = builder.AddNode("cast1", CAST, 1, 1); // cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16); // auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16); // auto cast2 = builder.AddNode("cast2", CAST, 1, 1); // cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16); // auto net = builder.AddNode("netout", NETOUTPUT, 1, 1); // builder.AddDataEdge(data, 0, cast1, 0); // builder.AddDataEdge(cast1, 0, trans, 0); // builder.AddDataEdge(trans, 0, cast2, 0); // builder.AddDataEdge(cast2, 0, net, 0); // ComputeGraphPtr compute_graph = builder.GetGraph(); // map options; // CastRemovePass cast_remove_pass; // DataType type = DT_FLOAT; // nodes_to_fuse.emplace_back(cast1); // nodes_to_fuse.emplace_back(trans); // nodes_to_fuse.emplace_back(cast2); // OpsKernelManager ops_kernel_manager; // cast_remove_pass.DoFuse(ops_kernel_manager, type, nodes_to_fuse); // EXPECT_EQ(compute_graph->GetAllNodesSize(),5); // std::vector to_be_deleted_cast_index; // to_be_deleted_cast_index.emplace_back(0); // to_be_deleted_cast_index.emplace_back(2); // (void)cast_remove_pass.DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse); // EXPECT_EQ(compute_graph->GetAllNodesSize(),3); // } TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) { std::vector nodes_to_fuse; auto builder = ut::GraphBuilder("g1"); auto data = builder.AddNode("data", DATA, 1, 1); auto cast1 = builder.AddNode("cast1", CAST, 1, 1); cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16); auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16); auto cast2 = builder.AddNode("cast2", CAST, 1, 1); cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16); auto net = builder.AddNode("netout", NETOUTPUT, 1, 1); builder.AddDataEdge(data, 0, cast1, 0); builder.AddDataEdge(cast1, 0, trans, 0); builder.AddDataEdge(trans, 0, cast2, 0); builder.AddDataEdge(cast2, 0, net, 0); ComputeGraphPtr compute_graph = builder.GetGraph(); map options; CastRemovePass cast_remove_pass; DataType type = DT_FLOAT; nodes_to_fuse.emplace_back(cast1); nodes_to_fuse.emplace_back(trans); nodes_to_fuse.emplace_back(cast2); cast_remove_pass.RemoveCast(type, nodes_to_fuse); EXPECT_EQ(compute_graph->GetAllNodesSize(),3); }