From: @ni100die Reviewed-by: @ji_chen Signed-off-by:pull/1550/MERGE
| @@ -647,6 +647,13 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) { | |||||
| std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | ||||
| tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | ||||
| GE_CHECK_NOTNULL(tbe_kernel); | GE_CHECK_NOTNULL(tbe_kernel); | ||||
| GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(), | |||||
| atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | |||||
| if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { | |||||
| std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra tbeKernel attr failed"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| if (tbe_kernel == nullptr) { | if (tbe_kernel == nullptr) { | ||||
| @@ -695,6 +702,14 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { | |||||
| GE_CHECK_NOTNULL(kernel_buffer.GetData()); | GE_CHECK_NOTNULL(kernel_buffer.GetData()); | ||||
| std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | ||||
| tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | ||||
| GE_CHECK_NOTNULL(tbe_kernel); | |||||
| GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | |||||
| node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | |||||
| if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { | |||||
| std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra tbeKernel attr failed"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | ||||
| @@ -1747,7 +1747,8 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
| // ge.graphType | // ge.graphType | ||||
| ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | |||||
| ret = | |||||
| ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | GE_IF_BOOL_EXEC(ret != SUCCESS, | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | ||||
| return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
| @@ -1789,20 +1790,18 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::ParseTrainGraphFlag(bool &options, bool &option) { | |||||
| Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { | |||||
| std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | ||||
| if (ge_instance_ptr == nullptr) { | if (ge_instance_ptr == nullptr) { | ||||
| GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | ||||
| option = false; | |||||
| train_flag = false; | |||||
| } else if (!ge_instance_ptr->isTrainMode()) { | } else if (!ge_instance_ptr->isTrainMode()) { | ||||
| option = false; | |||||
| train_flag = false; | |||||
| } else { // ge_instance_ptr->isTrainMode() is true | } else { // ge_instance_ptr->isTrainMode() is true | ||||
| if (!options) { | |||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, | |||||
| "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", options); | |||||
| return GE_GRAPH_OPTIONS_INVALID; | |||||
| train_flag = true; | |||||
| if (!run_flag) { | |||||
| GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); | |||||
| } | } | ||||
| option = true; | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -292,7 +292,7 @@ class GraphManager { | |||||
| static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); | static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); | ||||
| static Status ParseTrainGraphFlag(bool &options, bool &option); | |||||
| static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); | |||||
| static bool IsPerfLevelInvalid(int32_t perf_level); | static bool IsPerfLevelInvalid(int32_t perf_level); | ||||
| @@ -26,6 +26,11 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
| #include "graph/ge_context.h" | |||||
| namespace { | |||||
| const char *const kFlagOff = "0"; | |||||
| } // namespace | |||||
| namespace ge { | namespace ge { | ||||
| NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | ||||
| @@ -72,6 +77,13 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | |||||
| } | } | ||||
| Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | ||||
| // run_flag off means offline, no need insert global step node which type is variable | |||||
| std::string run_flag; | |||||
| if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) { | |||||
| GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), | |||||
| compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); | NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| if (output_node == nullptr) { | if (output_node == nullptr) { | ||||
| GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); | GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); | ||||
| @@ -693,6 +693,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/stop_gradient_pass_unittest.cc" | "graph/passes/stop_gradient_pass_unittest.cc" | ||||
| "graph/passes/prevent_gradient_pass_unittest.cc" | "graph/passes/prevent_gradient_pass_unittest.cc" | ||||
| "graph/passes/identity_pass_unittest.cc" | "graph/passes/identity_pass_unittest.cc" | ||||
| "graph/passes/global_step_insert_pass_unittest.cc" | |||||
| "graph/passes/placeholder_with_default_pass_unittest.cc" | "graph/passes/placeholder_with_default_pass_unittest.cc" | ||||
| "graph/passes/snapshot_pass_unittest.cc" | "graph/passes/snapshot_pass_unittest.cc" | ||||
| "graph/passes/guarantee_const_pass_unittest.cc" | "graph/passes/guarantee_const_pass_unittest.cc" | ||||
| @@ -161,3 +161,22 @@ TEST_F(UtestModelBuilderTest, test_save_atomic_bin) { | |||||
| op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node); | op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node); | ||||
| EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS); | EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestModelBuilderTest, test_model_save) { | |||||
| Graph2SubGraphInfoList subgraphs; | |||||
| std::map<std::string, int> stream_max_parallel_num; | |||||
| ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); | |||||
| ge::ModelBuilder builder(0, graph, subgraphs, stream_max_parallel_num, false); | |||||
| auto op_desc = make_shared<OpDesc>("Conv2d", "Conv2d"); | |||||
| auto kernel_buffer = static_cast<GeAttrValue::BYTES>(Buffer(10)); | |||||
| AttrUtils::SetStr(op_desc, ATTR_NAME_TBE_KERNEL_NAME, "Conv2d"); | |||||
| AttrUtils::SetBytes(op_desc, ATTR_NAME_TBE_KERNEL_BUFFER, kernel_buffer); | |||||
| ge::NodePtr node = graph->AddNode(op_desc); | |||||
| ge::Model ge_model; | |||||
| ge::GeModel ge_gemodel; | |||||
| builder.SaveDataToModel(ge_model, ge_gemodel); | |||||
| auto tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | |||||
| EXPECT_NE(tbe_kernel, nullptr); | |||||
| } | |||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/passes/global_step_insert_pass.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/passes/base_pass.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/tuning_utils.h" | |||||
| #include "graph_builder_utils.h" | |||||
| #include "graph/ge_context.h" | |||||
| #include "graph/ge_local_context.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| using namespace ge; | |||||
| class UtestGlobalStepInsertPass : public Test { | |||||
| protected: | |||||
| }; | |||||
| static ComputeGraphPtr BuildGraph1() { | |||||
| ge::ut::GraphBuilder builder("g1"); | |||||
| auto var1 = builder.AddNode("var1", "Variable", 0, 1); | |||||
| auto var2 = builder.AddNode("var2", "Variable", 0, 1); | |||||
| auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); | |||||
| auto out = builder.AddNode("out", "NetOutput", 1, 1); | |||||
| builder.AddDataEdge(var1, 0, identity1, 0); | |||||
| builder.AddControlEdge(var2, identity1); | |||||
| builder.AddDataEdge(identity1, 0, out, 0); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| TEST_F(UtestGlobalStepInsertPass, skip_insert) { | |||||
| auto graph = BuildGraph1(); | |||||
| std::string build_mode; | |||||
| std::map<string, string> options_map; | |||||
| options_map.insert({ge::RUN_FLAG, "0"}); | |||||
| ge::GetThreadLocalContext().SetGraphOption(options_map); | |||||
| GlobalStepInsertPass pass; | |||||
| Status status = pass.Run(graph); | |||||
| EXPECT_EQ(status, SUCCESS); | |||||
| NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); | |||||
| EXPECT_EQ(found_node, nullptr); | |||||
| } | |||||