@@ -647,6 +647,14 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) { | |||||
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | ||||
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | ||||
GE_CHECK_NOTNULL(tbe_kernel); | GE_CHECK_NOTNULL(tbe_kernel); | ||||
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(), | |||||
atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | |||||
if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { | |||||
std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra attr" + | |||||
FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; | |||||
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); | |||||
return ge::FAILED; | |||||
} | |||||
} | } | ||||
} | } | ||||
if (tbe_kernel == nullptr) { | if (tbe_kernel == nullptr) { | ||||
@@ -695,6 +703,15 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { | |||||
GE_CHECK_NOTNULL(kernel_buffer.GetData()); | GE_CHECK_NOTNULL(kernel_buffer.GetData()); | ||||
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | ||||
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | ||||
GE_CHECK_NOTNULL(tbe_kernel); | |||||
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | |||||
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | |||||
if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { | |||||
std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra attr" + | |||||
FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; | |||||
GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); | |||||
return ge::FAILED; | |||||
} | |||||
} | } | ||||
} | } | ||||
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | ||||
@@ -1686,7 +1686,8 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
// ge.graphType | // ge.graphType | ||||
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | |||||
ret = | |||||
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, | GE_IF_BOOL_EXEC(ret != SUCCESS, | ||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | ||||
return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
@@ -1728,20 +1729,21 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Diff view of GraphManager::ParseTrainGraphFlag. Each removed line (old parameters `options` / `option`)
// is immediately followed by its replacement (new parameters `run_flag` / `train_flag` / `tune_flag`).
//
// Behavior of the new version as shown below:
//   - GELib::GetInstance() returns nullptr: a warning is logged and train_flag is set to false.
//   - isTrainMode() is false: train_flag is set to false.
//   - isTrainMode() is true: if both run_flag and tune_flag are false, an error is logged and
//     GE_GRAPH_OPTIONS_INVALID is returned; otherwise train_flag is set to true.
//   - Every path that does not return early returns SUCCESS.
// The old version rejected a false run flag unconditionally in train mode; the new `tune_flag`
// parameter lets that check be bypassed.
Status GraphManager::ParseTrainGraphFlag(bool &options, bool &option) { | |||||
Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) { | |||||
std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | ||||
if (ge_instance_ptr == nullptr) { | if (ge_instance_ptr == nullptr) { | ||||
GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | ||||
option = false; | |||||
train_flag = false; | |||||
} else if (!ge_instance_ptr->isTrainMode()) { | } else if (!ge_instance_ptr->isTrainMode()) { | ||||
option = false; | |||||
train_flag = false; | |||||
} else { // ge_instance_ptr->isTrainMode() is true | } else { // ge_instance_ptr->isTrainMode() is true | ||||
if (!options) { | |||||
// tune mode no need check | |||||
if (!run_flag && !tune_flag) { | |||||
GELOGE(GE_GRAPH_OPTIONS_INVALID, | GELOGE(GE_GRAPH_OPTIONS_INVALID, | ||||
"Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", options); | |||||
"Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); | |||||
return GE_GRAPH_OPTIONS_INVALID; | return GE_GRAPH_OPTIONS_INVALID; | ||||
} | } | ||||
option = true; | |||||
train_flag = true; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -277,7 +277,7 @@ class GraphManager { | |||||
static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num); | static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num); | ||||
// Old declaration (first line) and its replacement (second line). The new signature takes
// run_flag and tune_flag by const reference and keeps train_flag as a non-const reference,
// which the definition writes to.
static Status ParseTrainGraphFlag(bool &options, bool &option); | |||||
static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag); | |||||
static bool IsPerfLevelInvalid(int32_t perf_level); | static bool IsPerfLevelInvalid(int32_t perf_level); | ||||
@@ -26,6 +26,8 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
#include "graph/ge_context.h" | |||||
#include "graph/tuning_utils.h" | |||||
namespace ge { | namespace ge { | ||||
NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | ||||
@@ -72,6 +74,12 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | |||||
} | } | ||||
Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | ||||
std::string build_mode; | |||||
if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) { | |||||
GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), | |||||
compute_graph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); | NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); | ||||
if (output_node == nullptr) { | if (output_node == nullptr) { | ||||
GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); | GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); | ||||