From 18d9518739982e0f97b74265514813b9053efc64 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Mon, 19 Apr 2021 20:14:16 +0800 Subject: [PATCH] add support for train_mode tune --- ge/graph/build/model_builder.cc | 6 ++---- .../ge/graph/build/model_builder_unittest.cc | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 08ead6d9..99abd73a 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -650,8 +650,7 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) { GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(), atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { - std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra attr" + - FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; + std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra tbeKernel attr failed"; GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); return ge::FAILED; } @@ -707,8 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { - std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra attr" + - FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; + std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra tbeKernel attr failed"; GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); return ge::FAILED; } diff --git a/tests/ut/ge/graph/build/model_builder_unittest.cc b/tests/ut/ge/graph/build/model_builder_unittest.cc index b9204dbc..d5efc9bb 100644 --- a/tests/ut/ge/graph/build/model_builder_unittest.cc +++ b/tests/ut/ge/graph/build/model_builder_unittest.cc @@ -161,3 +161,22 @@ TEST_F(UtestModelBuilderTest, test_save_atomic_bin) { op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node); EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS); } + +TEST_F(UtestModelBuilderTest, test_model_save) { + Graph2SubGraphInfoList subgraphs; + std::map stream_max_parallel_num; + ge::ComputeGraphPtr graph = make_shared(""); + ge::ModelBuilder builder(0, graph, subgraphs, stream_max_parallel_num, false); + + auto op_desc = make_shared("Conv2d", "Conv2d"); + auto kernel_buffer = static_cast(Buffer(10)); + AttrUtils::SetStr(op_desc, ATTR_NAME_TBE_KERNEL_NAME, "Conv2d"); + AttrUtils::SetBytes(op_desc, ATTR_NAME_TBE_KERNEL_BUFFER, kernel_buffer); + + ge::NodePtr node = graph->AddNode(op_desc); + ge::Model ge_model; + ge::GeModel ge_gemodel; + builder.SaveDataToModel(ge_model, ge_gemodel); + auto tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); + EXPECT_NE(tbe_kernel, nullptr); +}