Browse Source

!1775 run_flag

From: @dimitri_rose
Reviewed-by: 
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
83421b6c16
5 changed files with 18 additions and 44 deletions
  1. +1
    -1
      ge/common/profiling/ge_profiling.cc
  2. +11
    -24
      ge/graph/manager/graph_manager.cc
  3. +1
    -1
      ge/graph/manager/graph_manager.h
  4. +1
    -1
      inc/framework/common/profiling/ge_profiling.h
  5. +4
    -17
      tests/ut/ge/graph/manager/graph_manager_unittest.cc

+ 1
- 1
ge/common/profiling/ge_profiling.cc View File

@@ -216,6 +216,6 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le
return ge::SUCCESS;
}

GE_FUNC_VISIBILITY ge::Status ProSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) {
GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) {
return ge::SUCCESS;
}

+ 11
- 24
ge/graph/manager/graph_manager.cc View File

@@ -120,7 +120,6 @@ const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar";
const char *const kCheckPointGraph = "checkpoint_graph";
const char *const kVectorEngine = "VectorEngine";
const char *const kAIcoreEngine = "AIcoreEngine";
const char *const kRunFlagOffline = "0";
const int32_t kDynamicDimsTypeIsGetNext = 0;
const int32_t kDynamicDimsTypeIsData = 1;
const char *const kGetNextName = "IteratorV2";
@@ -1789,8 +1788,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return GE_GRAPH_OPTIONS_INVALID);

// ge.graphType
ret =
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID);
@@ -2436,6 +2434,8 @@ Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute
continue;
}
if (n->GetOpDesc()->GetType() == CONSTANT || n->GetOpDesc()->GetType() == CONSTANTOP) {
// reset const type depend on train_flag
options_.train_graph_flag ? n->GetOpDesc()->SetType(CONSTANTOP) : n->GetOpDesc()->SetType(CONSTANT);
if (n->GetOutAllNodes().empty() && n->GetInAllNodes().empty()) {
// it is an isolated constant, just remove it
if (GraphUtils::RemoveJustNode(compute_graph, n) != GRAPH_SUCCESS) {
@@ -2762,35 +2762,22 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
"Please pay attention to it.");
}

GE_CHK_STATUS_RET(ChangeConstType(compute_graph));
ChangeConstTypeWhenTraining(compute_graph);

GELOGI("End optimize after merge sub graph.");
return SUCCESS;
}

Status GraphManager::ChangeConstType(const ComputeGraphPtr &compute_graph) {
// run_flag off means offline, on means online
string run_flag;
(void)ge::GetContext().GetOption(ge::RUN_FLAG, run_flag);
// The constant for online is CONSTANTOP, and is CONSTANT for offline. They will be unified in future.
if (run_flag == kRunFlagOffline) {
GELOGI("Offline mode, change all Constant to Const.");
} else {
GELOGI("Online mode, change all Const to Constant.");
}
for (NodePtr &n : compute_graph->GetAllNodes()) {
GE_CHECK_NOTNULL(n);
if (n->GetType() == CONSTANT || n->GetType() == CONSTANTOP) {
auto op_desc = n->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (run_flag == kRunFlagOffline) {
op_desc->SetType(CONSTANT);
} else {
op_desc->SetType(CONSTANTOP);
void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) {
// The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future.
if (options_.train_graph_flag) {
for (NodePtr &n : compute_graph->GetAllNodes()) {
// This can ensure that n is not a null pointer
if (n->GetOpDesc()->GetType() == CONSTANT) {
n->GetOpDesc()->SetType(CONSTANTOP);
}
}
}
return SUCCESS;
}

Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {


+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -375,7 +375,7 @@ class GraphManager {
static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback,
Status ret, const string &log);

Status ChangeConstType(const ComputeGraphPtr &compute_graph);
void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph);

Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
ge::ComputeGraphPtr &compute_graph, uint64_t session_id);


+ 1
- 1
inc/framework/common/profiling/ge_profiling.h View File

@@ -43,6 +43,6 @@ GE_FUNC_VISIBILITY ge::Status RegProfCtrlCallback(MsprofCtrlCallback func);
GE_FUNC_VISIBILITY ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func);
GE_FUNC_VISIBILITY ge::Status RegProfReporterCallback(MsprofReporterCallback func);
GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len);
GE_FUNC_VISIBILITY ge::Status ProSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream);
GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream);

#endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_

+ 4
- 17
tests/ut/ge/graph/manager/graph_manager_unittest.cc View File

@@ -579,29 +579,16 @@ TEST_F(UtestGraphManagerTest, test_prerunthread_failed_2) {
// }

TEST_F(UtestGraphManagerTest, ChangeAndDeleteConst_success) {
std::map<string, string> options_map;
options_map.insert({ge::RUN_FLAG, "0"});
ge::GetThreadLocalContext().SetGraphOption(options_map);

GraphId graph_id = 1;
GraphManager graph_manager;
graph_manager.options_.train_graph_flag = true;

auto graph = CreateGraphWithIsolatedConst();
Status status = graph_manager.ChangeConstType(graph);
EXPECT_EQ(status, ge::SUCCESS);
auto constant1 = graph->FindFirstNodeMatchType("Constant");
EXPECT_EQ(constant1, nullptr);

options_map.clear();
options_map.insert({ge::RUN_FLAG, "1"});
ge::GetThreadLocalContext().SetGraphOption(options_map);
status = graph_manager.ChangeConstType(graph);
EXPECT_EQ(status, ge::SUCCESS);
constant1 = graph->FindFirstNodeMatchType("Constant");
EXPECT_NE(constant1, nullptr);
graph_manager.ChangeConstTypeWhenTraining(graph);
auto const1 = graph->FindFirstNodeMatchType("Const");
EXPECT_EQ(const1, nullptr);

status = graph_manager.RemoveIsolatedConstInThisGraph(graph);
Status status = graph_manager.RemoveIsolatedConstInThisGraph(graph);
EXPECT_EQ(status, ge::SUCCESS);
auto all_nodes = graph->GetDirectNode();
EXPECT_EQ(all_nodes.size(), 3);


Loading…
Cancel
Save