Browse Source

!1550 add support for train_mode tune

From: @ni100die
Reviewed-by: @ji_chen
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
72d3aed38f
7 changed files with 130 additions and 10 deletions
  1. +15
    -0
      ge/graph/build/model_builder.cc
  2. +8
    -9
      ge/graph/manager/graph_manager.cc
  3. +1
    -1
      ge/graph/manager/graph_manager.h
  4. +12
    -0
      ge/graph/passes/global_step_insert_pass.cc
  5. +1
    -0
      tests/ut/ge/CMakeLists.txt
  6. +19
    -0
      tests/ut/ge/graph/build/model_builder_unittest.cc
  7. +74
    -0
      tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc

+ 15
- 0
ge/graph/build/model_builder.cc View File

@@ -647,6 +647,13 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) {
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel); GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(),
atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());
if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) {
std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra tbeKernel attr failed";
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str());
return ge::FAILED;
}
} }
} }
if (tbe_kernel == nullptr) { if (tbe_kernel == nullptr) {
@@ -695,6 +702,14 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
GE_CHECK_NOTNULL(kernel_buffer.GetData()); GE_CHECK_NOTNULL(kernel_buffer.GetData());
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());
if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) {
std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra tbeKernel attr failed";
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str());
return ge::FAILED;
}
} }
} }
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue);


+ 8
- 9
ge/graph/manager/graph_manager.cc View File

@@ -1747,7 +1747,8 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return GE_GRAPH_OPTIONS_INVALID); return GE_GRAPH_OPTIONS_INVALID);


// ge.graphType // ge.graphType
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
ret =
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS, GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID); return GE_GRAPH_OPTIONS_INVALID);
@@ -1789,20 +1790,18 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return SUCCESS; return SUCCESS;
} }


Status GraphManager::ParseTrainGraphFlag(bool &options, bool &option) {
Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) {
std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance();
if (ge_instance_ptr == nullptr) { if (ge_instance_ptr == nullptr) {
GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized");
option = false;
train_flag = false;
} else if (!ge_instance_ptr->isTrainMode()) { } else if (!ge_instance_ptr->isTrainMode()) {
option = false;
train_flag = false;
} else { // ge_instance_ptr->isTrainMode() is true } else { // ge_instance_ptr->isTrainMode() is true
if (!options) {
GELOGE(GE_GRAPH_OPTIONS_INVALID,
"Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", options);
return GE_GRAPH_OPTIONS_INVALID;
train_flag = true;
if (!run_flag) {
GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag);
} }
option = true;
} }
return SUCCESS; return SUCCESS;
} }


+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -292,7 +292,7 @@ class GraphManager {


static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num); static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num);


static Status ParseTrainGraphFlag(bool &options, bool &option);
static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag);


static bool IsPerfLevelInvalid(int32_t perf_level); static bool IsPerfLevelInvalid(int32_t perf_level);




+ 12
- 0
ge/graph/passes/global_step_insert_pass.cc View File

@@ -26,6 +26,11 @@
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
#include "graph/passes/pass_utils.h" #include "graph/passes/pass_utils.h"
#include "graph/ge_context.h"

namespace {
const char *const kFlagOff = "0";
} // namespace


namespace ge { namespace ge {
NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
@@ -72,6 +77,13 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
} }


Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) {
// run_flag off means offline, no need insert global step node which type is variable
std::string run_flag;
if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) {
GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(),
compute_graph->GetName().c_str());
return SUCCESS;
}
NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT);
if (output_node == nullptr) { if (output_node == nullptr) {
GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID());


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -693,6 +693,7 @@ set(PASS_TEST_FILES
"graph/passes/stop_gradient_pass_unittest.cc" "graph/passes/stop_gradient_pass_unittest.cc"
"graph/passes/prevent_gradient_pass_unittest.cc" "graph/passes/prevent_gradient_pass_unittest.cc"
"graph/passes/identity_pass_unittest.cc" "graph/passes/identity_pass_unittest.cc"
"graph/passes/global_step_insert_pass_unittest.cc"
"graph/passes/placeholder_with_default_pass_unittest.cc" "graph/passes/placeholder_with_default_pass_unittest.cc"
"graph/passes/snapshot_pass_unittest.cc" "graph/passes/snapshot_pass_unittest.cc"
"graph/passes/guarantee_const_pass_unittest.cc" "graph/passes/guarantee_const_pass_unittest.cc"


+ 19
- 0
tests/ut/ge/graph/build/model_builder_unittest.cc View File

@@ -161,3 +161,22 @@ TEST_F(UtestModelBuilderTest, test_save_atomic_bin) {
op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node); op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node);
EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS); EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS);
} }

TEST_F(UtestModelBuilderTest, test_model_save) {
Graph2SubGraphInfoList subgraphs;
std::map<std::string, int> stream_max_parallel_num;
ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("");
ge::ModelBuilder builder(0, graph, subgraphs, stream_max_parallel_num, false);

auto op_desc = make_shared<OpDesc>("Conv2d", "Conv2d");
auto kernel_buffer = static_cast<GeAttrValue::BYTES>(Buffer(10));
AttrUtils::SetStr(op_desc, ATTR_NAME_TBE_KERNEL_NAME, "Conv2d");
AttrUtils::SetBytes(op_desc, ATTR_NAME_TBE_KERNEL_BUFFER, kernel_buffer);

ge::NodePtr node = graph->AddNode(op_desc);
ge::Model ge_model;
ge::GeModel ge_gemodel;
builder.SaveDataToModel(ge_model, ge_gemodel);
auto tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
EXPECT_NE(tbe_kernel, nullptr);
}

+ 74
- 0
tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc View File

@@ -0,0 +1,74 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define protected public
#define private public
#include "graph/passes/global_step_insert_pass.h"

#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/compute_graph.h"
#include "graph/op_desc.h"
#include "graph/passes/base_pass.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/tuning_utils.h"
#include "graph_builder_utils.h"
#include "graph/ge_context.h"
#include "graph/ge_local_context.h"
#include "inc/pass_manager.h"
#undef protected
#undef private

using namespace std;
using namespace testing;
using namespace ge;

class UtestGlobalStepInsertPass : public Test {
protected:
};

static ComputeGraphPtr BuildGraph1() {
ge::ut::GraphBuilder builder("g1");
auto var1 = builder.AddNode("var1", "Variable", 0, 1);
auto var2 = builder.AddNode("var2", "Variable", 0, 1);
auto identity1 = builder.AddNode("identity1", "Identity", 1, 1);
auto out = builder.AddNode("out", "NetOutput", 1, 1);

builder.AddDataEdge(var1, 0, identity1, 0);
builder.AddControlEdge(var2, identity1);
builder.AddDataEdge(identity1, 0, out, 0);
return builder.GetGraph();
}

TEST_F(UtestGlobalStepInsertPass, skip_insert) {
auto graph = BuildGraph1();
std::string build_mode;
std::map<string, string> options_map;
options_map.insert({ge::RUN_FLAG, "0"});
ge::GetThreadLocalContext().SetGraphOption(options_map);
GlobalStepInsertPass pass;
Status status = pass.Run(graph);
EXPECT_EQ(status, SUCCESS);
NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP);
EXPECT_EQ(found_node, nullptr);
}

Loading…
Cancel
Save