Browse Source

add support for train_mode tune

tags/v1.3.0
gengchao4@huawei.com 3 years ago
parent
commit
bd2173a9f3
2 changed files with 75 additions and 0 deletions
  1. +1
    -0
      tests/ut/ge/CMakeLists.txt
  2. +74
    -0
      tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc

+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -691,6 +691,7 @@ set(PASS_TEST_FILES
"graph/passes/stop_gradient_pass_unittest.cc"
"graph/passes/prevent_gradient_pass_unittest.cc"
"graph/passes/identity_pass_unittest.cc"
"graph/passes/global_step_insert_pass_unittest.cc"
"graph/passes/placeholder_with_default_pass_unittest.cc"
"graph/passes/snapshot_pass_unittest.cc"
"graph/passes/guarantee_const_pass_unittest.cc"


+ 74
- 0
tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc View File

@@ -0,0 +1,74 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define protected public
#define private public
#include "graph/passes/global_step_insert_pass.h"

#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/compute_graph.h"
#include "graph/op_desc.h"
#include "graph/passes/base_pass.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/tuning_utils.h"
#include "graph_builder_utils.h"
#include "graph/ge_context.h"
#include "graph/ge_local_context.h"
#include "inc/pass_manager.h"
#undef protected
#undef private

using namespace std;
using namespace testing;
using namespace ge;

class UtestGlobalStepInsertPass : public Test {
protected:
};

static ComputeGraphPtr BuildGraph1() {
ge::ut::GraphBuilder builder("g1");
auto var1 = builder.AddNode("var1", "Variable", 0, 1);
auto var2 = builder.AddNode("var2", "Variable", 0, 1);
auto identity1 = builder.AddNode("identity1", "Identity", 1, 1);
auto out = builder.AddNode("out", "NetOutput", 1, 1);

builder.AddDataEdge(var1, 0, identity1, 0);
builder.AddControlEdge(var2, identity1);
builder.AddDataEdge(identity1, 0, out, 0);
return builder.GetGraph();
}

TEST_F(UtestGlobalStepInsertPass, skip_tune) {
auto graph = BuildGraph1();
std::string build_mode;
std::map<string, string> options_map;
options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING});
ge::GetThreadLocalContext().SetGraphOption(options_map);
GlobalStepInsertPass pass;
Status status = pass.Run(graph);
EXPECT_EQ(status, SUCCESS);
NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP);
EXPECT_EQ(found_node, nullptr);
}

Loading…
Cancel
Save