| @@ -154,7 +154,7 @@ bool HcclTask::SetSecondaryStream() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | ||||
| GE_IF_BOOL_EXEC(stream == nullptr, return false); | |||||
| GE_RT_FALSE_CHECK_NOTNULL(stream); | |||||
| secondary_stream_vec[index] = stream; | secondary_stream_vec[index] = stream; | ||||
| } | } | ||||
| @@ -199,6 +199,24 @@ void ClearOption(NamesToPass names_to_pass) { | |||||
| name_to_pass.second->ClearOptions(); | name_to_pass.second->ClearOptions(); | ||||
| } | } | ||||
| } | } | ||||
| bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { | |||||
| if (node == nullptr) { | |||||
| GELOGW("node is null"); | |||||
| return false; | |||||
| } | |||||
| if (during_pass_node_set.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) { | Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector<int> &io_map) { | ||||
| @@ -277,17 +295,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| nodes.pop_front(); | nodes.pop_front(); | ||||
| (void)during_pass_node_set.nodes_re_pass.erase(node); | (void)during_pass_node_set.nodes_re_pass.erase(node); | ||||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); | |||||
| if (during_pass_node_set.nodes_deleted.count(node) > 0) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| if (!CheckNode(node, during_pass_node_set)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | ||||
| auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | ||||
| @@ -333,8 +343,11 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| during_pass_node_set.nodes_last.clear(); | during_pass_node_set.nodes_last.clear(); | ||||
| } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); | } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); | ||||
| GE_IF_BOOL_EXEC(re_pass_times == kMaxRePassTimes, GELOGW("re_pass_times should not come to %d", kMaxRePassTimes)); | |||||
| if (re_pass_times == kMaxRePassTimes) { | |||||
| GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); | |||||
| } | |||||
| GELOGD("All passes runs end"); | GELOGD("All passes runs end"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { | Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { | ||||
| @@ -41,7 +41,9 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||||
| bool graph_change = false; | bool graph_change = false; | ||||
| // 1. Add FP/BP flow ctrl (big cycle) | // 1. Add FP/BP flow ctrl (big cycle) | ||||
| for (auto &node : compute_graph->GetDirectNode()) { | for (auto &node : compute_graph->GetDirectNode()) { | ||||
| GE_IF_BOOL_EXEC(node == nullptr, continue); | |||||
| if (node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
| uint32_t true_stream_id = 0; | uint32_t true_stream_id = 0; | ||||
| bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id); | bool is_found = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, true_stream_id); | ||||
| @@ -63,12 +65,14 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||||
| // 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle) | // 2. Add special node flow ctrl. eg, IteratorGetNext. (small cycle) | ||||
| // NOTE: Small cycle share the variables with big cycle. | // NOTE: Small cycle share the variables with big cycle. | ||||
| for (auto &node : compute_graph->GetDirectNode()) { | for (auto &node : compute_graph->GetDirectNode()) { | ||||
| GE_IF_BOOL_EXEC(node == nullptr, continue); | |||||
| if (node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
| bool need_cycle_flag = false; | bool need_cycle_flag = false; | ||||
| bool is_found = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag); | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, need_cycle_flag); | |||||
| // small cycle flag is need_stream_cycle_event == true | // small cycle flag is need_stream_cycle_event == true | ||||
| if (is_found && need_cycle_flag) { | |||||
| if (need_cycle_flag) { | |||||
| Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node); | Status ret = AddSpecialNodeIteratorCtrl(compute_graph, node); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[Add][SpecialNodeIteratorCtrl] failed, node:%s, graph:%s.", | GELOGE(ret, "[Add][SpecialNodeIteratorCtrl] failed, node:%s, graph:%s.", | ||||
| @@ -1475,9 +1475,11 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input, | |||||
| GeTensorDesc desc(user_input[index].GetTensorDesc()); | GeTensorDesc desc(user_input[index].GetTensorDesc()); | ||||
| // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. | ||||
| auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER); | ||||
| GE_CHK_STATUS_RET(CheckInternalFormat(input_node, desc, tune_flag), "[Check][InternalFormat] on %s failed.", | |||||
| op->GetName().c_str()); | |||||
| ret = CheckInternalFormat(input_node, desc, tune_flag); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| auto data_type = desc.GetDataType(); | auto data_type = desc.GetDataType(); | ||||
| uint32_t length = 1; | uint32_t length = 1; | ||||
| bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | ||||