|
|
@@ -122,6 +122,7 @@ const char *const kVectorEngine = "VectorEngine"; |
|
|
|
const char *const kAIcoreEngine = "AIcoreEngine"; |
|
|
|
const int32_t kDynamicDimsTypeIsGetNext = 0; |
|
|
|
const int32_t kDynamicDimsTypeIsData = 1; |
|
|
|
const int32_t kBase = 10; |
|
|
|
const char *const kGetNextName = "IteratorV2"; |
|
|
|
const uint32_t kInitGraphCount = 1; |
|
|
|
const uint32_t kNotAdded = 0; |
|
|
@@ -1788,7 +1789,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti |
|
|
|
return GE_GRAPH_OPTIONS_INVALID); |
|
|
|
|
|
|
|
// ge.graphType |
|
|
|
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); |
|
|
|
ret = ParseTrainGraphFlag(options_.train_graph_flag); |
|
|
|
GE_IF_BOOL_EXEC(ret != SUCCESS, |
|
|
|
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); |
|
|
|
return GE_GRAPH_OPTIONS_INVALID); |
|
|
@@ -1833,19 +1834,17 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { |
|
|
|
std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); |
|
|
|
if (ge_instance_ptr == nullptr) { |
|
|
|
GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); |
|
|
|
train_flag = false; |
|
|
|
} else if (!ge_instance_ptr->isTrainMode()) { |
|
|
|
train_flag = false; |
|
|
|
} else { // ge_instance_ptr->isTrainMode() is true |
|
|
|
train_flag = true; |
|
|
|
if (!run_flag) { |
|
|
|
GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); |
|
|
|
// OPTION_GRAPH_RUN_MODE is supposed to be a session-level option, but it used to be set to global-level in the past. |
|
|
|
// If can not parse from session, it can parse from global by GetContext(). |
|
|
|
Status GraphManager::ParseTrainGraphFlag(bool &train_flag) { |
|
|
|
train_flag = false; |
|
|
|
string run_mode; |
|
|
|
if (GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS && !run_mode.empty()) { |
|
|
|
if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, kBase)) >= TRAIN) { |
|
|
|
train_flag = true; |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGI("Is train flag: %d.", train_flag); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|