From bd2173a9f3bf2e9efe82b0f9465f06a5d97a64e2 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Wed, 21 Apr 2021 10:24:52 +0800 Subject: [PATCH] add support for train_mode tune --- tests/ut/ge/CMakeLists.txt | 1 + .../global_step_insert_pass_unittest.cc | 74 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index f2f08106..0c11c9d2 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -691,6 +691,7 @@ set(PASS_TEST_FILES "graph/passes/stop_gradient_pass_unittest.cc" "graph/passes/prevent_gradient_pass_unittest.cc" "graph/passes/identity_pass_unittest.cc" + "graph/passes/global_step_insert_pass_unittest.cc" "graph/passes/placeholder_with_default_pass_unittest.cc" "graph/passes/snapshot_pass_unittest.cc" "graph/passes/guarantee_const_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc new file mode 100644 index 00000000..98e303c7 --- /dev/null +++ b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "graph/passes/global_step_insert_pass.h" + +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/op_desc.h" +#include "graph/passes/base_pass.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/tuning_utils.h" +#include "graph_builder_utils.h" +#include "graph/ge_context.h" +#include "graph/ge_local_context.h" +#include "inc/pass_manager.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; +using namespace ge; + +class UtestGlobalStepInsertPass : public Test { + protected: +}; + +static ComputeGraphPtr BuildGraph1() { + ge::ut::GraphBuilder builder("g1"); + auto var1 = builder.AddNode("var1", "Variable", 0, 1); + auto var2 = builder.AddNode("var2", "Variable", 0, 1); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto out = builder.AddNode("out", "NetOutput", 1, 1); + + builder.AddDataEdge(var1, 0, identity1, 0); + builder.AddControlEdge(var2, identity1); + builder.AddDataEdge(identity1, 0, out, 0); + return builder.GetGraph(); +} + +TEST_F(UtestGlobalStepInsertPass, skip_tune) { + auto graph = BuildGraph1(); + std::string build_mode; + std::map options_map; + options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING}); + ge::GetThreadLocalContext().SetGraphOption(options_map); + GlobalStepInsertPass pass; + Status status = pass.Run(graph); + EXPECT_EQ(status, SUCCESS); + NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); + EXPECT_EQ(found_node, nullptr); +}