|
|
@@ -120,7 +120,6 @@ const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar"; |
|
|
|
const char *const kCheckPointGraph = "checkpoint_graph"; |
|
|
|
const char *const kVectorEngine = "VectorEngine"; |
|
|
|
const char *const kAIcoreEngine = "AIcoreEngine"; |
|
|
|
const char *const kRunFlagOffline = "0"; |
|
|
|
const int32_t kDynamicDimsTypeIsGetNext = 0; |
|
|
|
const int32_t kDynamicDimsTypeIsData = 1; |
|
|
|
const char *const kGetNextName = "IteratorV2"; |
|
|
@@ -1789,8 +1788,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti |
|
|
|
return GE_GRAPH_OPTIONS_INVALID); |
|
|
|
|
|
|
|
// ge.graphType |
|
|
|
ret = |
|
|
|
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); |
|
|
|
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); |
|
|
|
GE_IF_BOOL_EXEC(ret != SUCCESS, |
|
|
|
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); |
|
|
|
return GE_GRAPH_OPTIONS_INVALID); |
|
|
@@ -2436,6 +2434,8 @@ Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (n->GetOpDesc()->GetType() == CONSTANT || n->GetOpDesc()->GetType() == CONSTANTOP) { |
|
|
|
// reset const type depend on train_flag |
|
|
|
options_.train_graph_flag ? n->GetOpDesc()->SetType(CONSTANTOP) : n->GetOpDesc()->SetType(CONSTANT); |
|
|
|
if (n->GetOutAllNodes().empty() && n->GetInAllNodes().empty()) { |
|
|
|
// it is an isolated constant, just remove it |
|
|
|
if (GraphUtils::RemoveJustNode(compute_graph, n) != GRAPH_SUCCESS) { |
|
|
@@ -2762,35 +2762,22 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { |
|
|
|
"Please pay attention to it."); |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(ChangeConstType(compute_graph)); |
|
|
|
ChangeConstTypeWhenTraining(compute_graph); |
|
|
|
|
|
|
|
GELOGI("End optimize after merge sub graph."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status GraphManager::ChangeConstType(const ComputeGraphPtr &compute_graph) { |
|
|
|
// run_flag off means offline, on means online |
|
|
|
string run_flag; |
|
|
|
(void)ge::GetContext().GetOption(ge::RUN_FLAG, run_flag); |
|
|
|
// The constant for online is CONSTANTOP, and is CONSTANT for offline. They will be unified in future. |
|
|
|
if (run_flag == kRunFlagOffline) { |
|
|
|
GELOGI("Offline mode, change all Constant to Const."); |
|
|
|
} else { |
|
|
|
GELOGI("Online mode, change all Const to Constant."); |
|
|
|
} |
|
|
|
for (NodePtr &n : compute_graph->GetAllNodes()) { |
|
|
|
GE_CHECK_NOTNULL(n); |
|
|
|
if (n->GetType() == CONSTANT || n->GetType() == CONSTANTOP) { |
|
|
|
auto op_desc = n->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
if (run_flag == kRunFlagOffline) { |
|
|
|
op_desc->SetType(CONSTANT); |
|
|
|
} else { |
|
|
|
op_desc->SetType(CONSTANTOP); |
|
|
|
void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) { |
|
|
|
// The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future. |
|
|
|
if (options_.train_graph_flag) { |
|
|
|
for (NodePtr &n : compute_graph->GetAllNodes()) { |
|
|
|
// This can ensure that n is not a null pointer |
|
|
|
if (n->GetOpDesc()->GetType() == CONSTANT) { |
|
|
|
n->GetOpDesc()->SetType(CONSTANTOP); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { |
|
|
|