Browse Source

add support for train_mode tune

tags/v1.3.0
gengchao4@huawei.com 3 years ago
parent
commit
18d9518739
2 changed files with 21 additions and 4 deletions
  1. +2
    -4
      ge/graph/build/model_builder.cc
  2. +19
    -0
      tests/ut/ge/graph/build/model_builder_unittest.cc

+ 2
- 4
ge/graph/build/model_builder.cc View File

@@ -650,8 +650,7 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) {
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(),
atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());
if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) {
std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra attr" +
FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed";
std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra tbeKernel attr failed";
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str());
return ge::FAILED;
}
@@ -707,8 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());
if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) {
std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra attr" +
FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed";
std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra tbeKernel attr failed";
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str());
return ge::FAILED;
}


+ 19
- 0
tests/ut/ge/graph/build/model_builder_unittest.cc View File

@@ -161,3 +161,22 @@ TEST_F(UtestModelBuilderTest, test_save_atomic_bin) {
op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node);
EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS);
}

TEST_F(UtestModelBuilderTest, test_model_save) {
Graph2SubGraphInfoList subgraphs;
std::map<std::string, int> stream_max_parallel_num;
ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("");
ge::ModelBuilder builder(0, graph, subgraphs, stream_max_parallel_num, false);

auto op_desc = make_shared<OpDesc>("Conv2d", "Conv2d");
auto kernel_buffer = static_cast<GeAttrValue::BYTES>(Buffer(10));
AttrUtils::SetStr(op_desc, ATTR_NAME_TBE_KERNEL_NAME, "Conv2d");
AttrUtils::SetBytes(op_desc, ATTR_NAME_TBE_KERNEL_BUFFER, kernel_buffer);

ge::NodePtr node = graph->AddNode(op_desc);
ge::Model ge_model;
ge::GeModel ge_gemodel;
builder.SaveDataToModel(ge_model, ge_gemodel);
auto tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
EXPECT_NE(tbe_kernel, nullptr);
}

Loading…
Cancel
Save