Browse Source

modify

tags/v1.3.0
wjm 4 years ago
parent
commit
67de533b5a
4 changed files with 37 additions and 18 deletions
  1. +1
    -1
      ge/ge_runtime/task/hccl_task.cc
  2. +23
    -10
      ge/graph/passes/base_pass.cc
  3. +8
    -4
      ge/graph/passes/flow_ctrl_pass.cc
  4. +5
    -3
      ge/graph/preprocess/graph_preprocess.cc

+ 1
- 1
ge/ge_runtime/task/hccl_task.cc View File

@@ -154,7 +154,7 @@ bool HcclTask::SetSecondaryStream() {
return false; return false;
} }
stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream);
GE_IF_BOOL_EXEC(stream == nullptr, return false);
GE_RT_FALSE_CHECK_NOTNULL(stream);


secondary_stream_vec[index] = stream; secondary_stream_vec[index] = stream;
} }


+ 23
- 10
ge/graph/passes/base_pass.cc View File

@@ -199,6 +199,24 @@ void ClearOption(NamesToPass names_to_pass) {
name_to_pass.second->ClearOptions(); name_to_pass.second->ClearOptions();
} }
} }

bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) {
if (node == nullptr) {
GELOGW("node is null");
return false;
}
if (during_pass_node_set.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
return false;
}
if (during_pass_node_set.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
return false;
}

return true;
}
} // namespace } // namespace


Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) { Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) {
@@ -277,17 +295,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
nodes.pop_front(); nodes.pop_front();


(void)during_pass_node_set.nodes_re_pass.erase(node); (void)during_pass_node_set.nodes_re_pass.erase(node);
GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue);
if (during_pass_node_set.nodes_deleted.count(node) > 0) {
GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str());
continue;
}
if (during_pass_node_set.nodes_suspend.count(node) > 0) {
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.",
node->GetName().c_str());
if (!CheckNode(node, during_pass_node_set)) {
continue; continue;
} }

AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set);


auto ret = RunPasses(node, names_to_passes, during_pass_node_set); auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
@@ -333,8 +343,11 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
during_pass_node_set.nodes_last.clear(); during_pass_node_set.nodes_last.clear();
} while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes);


GE_IF_BOOL_EXEC(re_pass_times == kMaxRePassTimes, GELOGW("re_pass_times should not come to %d", kMaxRePassTimes));
if (re_pass_times == kMaxRePassTimes) {
GELOGW("re_pass_times should not come to %d", kMaxRePassTimes);
}
GELOGD("All passes runs end"); GELOGD("All passes runs end");

return SUCCESS; return SUCCESS;
} }
Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) {


+ 8
- 4
ge/graph/passes/flow_ctrl_pass.cc View File

@@ -41,7 +41,9 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
bool graph_change = false; bool graph_change = false;
// 1. Add FP/BP flow ctrl (big cycle) // 1. Add FP/BP flow ctrl (big cycle)
for (auto &node : compute_graph->GetDirectNode()) { for (auto &node : compute_graph->GetDirectNode()) {
GE_IF_BOOL_EXEC(node == nullptr, continue);
if (node == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
uint32_t true_stream_id = 0; uint32_t true_stream_id = 0;
bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id); bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id);
@@ -63,12 +65,14 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
// 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle) // 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle)
// NOTE: Small cycle share the variables with big cycle. // NOTE: Small cycle share the variables with big cycle.
for (auto &node : compute_graph->GetDirectNode()) { for (auto &node : compute_graph->GetDirectNode()) {
GE_IF_BOOL_EXEC(node == nullptr, continue);
if (node == nullptr) {
continue;
}
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
bool need_cycle_flag = false; bool need_cycle_flag = false;
bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);
(void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag);
// small cycle flag is need_stream_cycle_event == true // small cycle flag is need_stream_cycle_event == true
if (is_found && need_cycle_flag) {
if (need_cycle_flag) {
Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node); Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Add][SpecialNodeIteratorCtrl] failed, node:%s, graph:%s.", GELOGE(ret, "[Add][SpecialNodeIteratorCtrl] failed, node:%s, graph:%s.",


+ 5
- 3
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1475,9 +1475,11 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
GeTensorDesc desc(user_input[index].GetTensorDesc()); GeTensorDesc desc(user_input[index].GetTensorDesc());
// data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM.
auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
GE_CHK_STATUS_RET(CheckInternalFormat(input_node, desc, tune_flag), "[Check][InternalFormat] on %s failed.",
op->GetName().c_str());

ret = CheckInternalFormat(input_node, desc, tune_flag);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str());
return ret;
}
auto data_type = desc.GetDataType(); auto data_type = desc.GetDataType();
uint32_t length = 1; uint32_t length = 1;
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);


Loading…
Cancel
Save