| @@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||
| FindNodes(graph); | |||
| for (const auto &node : need_label_nodes_) { | |||
| OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||
| GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||
| } | |||
| GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||
| } | |||
| GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); | |||
| @@ -83,11 +79,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | |||
| /// | |||
| Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||
| std::string stream_label; | |||
| if (AttachFlag(node, stream_label) != SUCCESS) { | |||
| GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| std::unordered_set<NodePtr> branch_nodes; | |||
| std::unordered_set<NodePtr> visited; | |||
| std::stack<NodePtr> nodes; | |||
| nodes.push(node); | |||
| static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; | |||
| while (!nodes.empty()) { | |||
| NodePtr cur_node = nodes.top(); | |||
| @@ -95,10 +95,7 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||
| if (visited.count(cur_node) > 0) { | |||
| continue; | |||
| } | |||
| if (AttachFlag(cur_node, stream_label) != SUCCESS) { | |||
| GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| const std::string &type = cur_node->GetType(); | |||
| for (const auto &out_node : cur_node->GetOutAllNodes()) { | |||
| @@ -115,10 +112,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||
| visited.insert(cur_node); | |||
| } | |||
| if (node->GetType() == STREAMSWITCH) { | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||
| } | |||
| for (const NodePtr &tmp_node : branch_nodes) { | |||
| GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); | |||
| GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); | |||
| @@ -148,11 +141,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea | |||
| GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, | |||
| "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); | |||
| stream_label += (value ? "_t" : "_f"); | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||
| } else if (type == STREAMMERGE) { | |||
| stream_label = node->GetName(); | |||
| GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||
| } else if ((type == EXIT) || (type == REFEXIT)) { | |||
| GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||
| } | |||
| return SUCCESS; | |||
| @@ -166,12 +158,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||
| std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; | |||
| for (const auto &enter_node : enter_nodes_) { | |||
| for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { | |||
| if (out_ctrl_node->GetType() == STREAMACTIVE) { | |||
| if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||
| enter_active_map[out_ctrl_node] = {enter_node}; | |||
| } else { | |||
| enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||
| } | |||
| if (out_ctrl_node->GetType() != STREAMACTIVE) { | |||
| continue; | |||
| } | |||
| if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||
| enter_active_map[out_ctrl_node] = {enter_node}; | |||
| } else { | |||
| enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||
| } | |||
| } | |||
| } | |||
| @@ -226,9 +219,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||
| std::string stream_label; | |||
| GE_CHECK_NOTNULL(active_node); | |||
| (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); | |||
| if (stream_label.empty()) { | |||
| GELOGW("stream_label of enter_active & enter_nodes is empty."); | |||
| GELOGD("stream_label of enter_active & enter_nodes is empty."); | |||
| return SUCCESS; | |||
| } | |||
| @@ -238,7 +230,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||
| GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||
| } | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); | |||
| return SUCCESS; | |||
| } | |||
| @@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) { | |||
| OutDataAnchorPtr cond_out_anchor = nullptr; | |||
| InDataAnchorPtr cond_in_anchor = nullptr; | |||
| Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | |||
| if (ret == NOT_CHANGED) { | |||
| return SUCCESS; | |||
| } else if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| int32_t cond_index = 0; | |||
| GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); | |||
| bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); | |||
| @@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, | |||
| std::string type = node->GetType(); | |||
| if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { | |||
| if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | |||
| GELOGE(FAILED, "Get cond_info for if node failed."); | |||
| GELOGE(FAILED, "Get cond_info for if/case node failed."); | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| GELOGI("no need cond_pass for node %s.", node->GetName().c_str()); | |||
| GELOGI("no need cond_remove_pass for node %s.", node->GetName().c_str()); | |||
| return NOT_CHANGED; | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| #include "graph/passes/enter_pass.h" | |||
| #include "debug/ge_attr_define.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/common/debug/log.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| @@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) { | |||
| } | |||
| Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | |||
| auto out_nodes_of_in_node = in_node->GetOutAllNodes(); | |||
| if (out_nodes_of_in_node.size() != kOutNodesNum) { | |||
| if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty() { | |||
| return SUCCESS; | |||
| } | |||
| if (!node->GetOutControlNodes().empty()) { | |||
| bool is_constant_flag = true; | |||
| (void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag); | |||
| if (!is_constant_flag) { | |||
| return SUCCESS; | |||
| } | |||
| for (const auto &out_node : node->GetOutDataNodes()) { | |||
| GE_CHECK_NOTNULL(out_node); | |||
| if (out_node->GetType() == MERGE) { | |||
| return SUCCESS; | |||
| } | |||
| } | |||
| GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | |||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | |||
| auto out_data_anchor = node->GetOutDataAnchor(0); | |||
| const auto &out_data_anchor = node->GetOutDataAnchor(0); | |||
| GE_CHECK_NOTNULL(out_data_anchor); | |||
| for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | |||
| GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | |||
| } | |||
| auto graph = node->GetOwnerComputeGraph(); | |||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node)) | |||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||
| AddNodeDeleted(node); | |||
| AddRePassNodesWithInOut(in_node); | |||
| return SUCCESS; | |||
| @@ -137,7 +137,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n | |||
| for_info.ctrl_inputs = std::move(ctrl_inputs); | |||
| for_info.ctrl_outputs = std::move(ctrl_outputs); | |||
| GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); | |||
| GELOGI("Build for_info for node %s success.", node->GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| @@ -159,13 +159,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index | |||
| return nullptr; | |||
| } | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); | |||
| return nullptr; | |||
| } | |||
| return peer_out_anchor; | |||
| return in_data_anchor->GetPeerOutAnchor(); | |||
| } | |||
| /// | |||
| @@ -186,20 +180,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||
| uint32_t input_data_num = node->GetAllInDataAnchorsSize(); | |||
| for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { | |||
| InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); | |||
| if (in_data_anchor == nullptr) { | |||
| GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); | |||
| return FAILED; | |||
| } | |||
| GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr, | |||
| GELOGW("Get null input by index %d from node %s ", | |||
| in_data_anchor->GetIdx(), node->GetName().c_str()); | |||
| continue); | |||
| GE_CHECK_NOTNULL(in_data_anchor); | |||
| data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); | |||
| } | |||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; | |||
| for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| peer_in_data_anchors.emplace_back(peer_in_data_anchor); | |||
| } | |||
| data_outputs.emplace_back(peer_in_data_anchors); | |||
| @@ -207,13 +194,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||
| InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); | |||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
| for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||
| for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||
| ctrl_inputs.emplace_back(peer_out_ctrl_anchor); | |||
| } | |||
| OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); | |||
| GE_CHECK_NOTNULL(out_ctrl_anchor); | |||
| for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
| for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
| ctrl_outputs.emplace_back(peer_in_ctrl_anchor); | |||
| } | |||
| @@ -21,16 +21,12 @@ | |||
| #include <vector> | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/ge/ge_util.h" | |||
| #include "graph/common/omg_util.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/passes/pass_utils.h" | |||
| using domi::PARAM_INVALID; | |||
| using domi::SUCCESS; | |||
| namespace ge { | |||
| const int kValueIndexOutputIndex = 1; | |||
| @@ -47,13 +43,12 @@ Status MergePass::Run(NodePtr &node) { | |||
| return SUCCESS; | |||
| } | |||
| auto out_data_anchors = node->GetAllOutDataAnchors(); | |||
| if (out_data_anchors.empty()) { | |||
| if (node->GetAllOutDataAnchors().empty()) { | |||
| GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| auto in_data_nodes = node->GetInDataNodes(); | |||
| const auto in_data_nodes = node->GetInDataNodes(); | |||
| switch (in_data_nodes.size()) { | |||
| case 0: { | |||
| /// Case A: input_count = 0, the output of merge node is inactive as well | |||
| @@ -22,9 +22,6 @@ | |||
| #include "graph/common/omg_util.h" | |||
| #include "graph/utils/type_utils.h" | |||
| using std::string; | |||
| using std::vector; | |||
| namespace ge { | |||
| Status MultiBatchPass::Run(ComputeGraphPtr graph) { | |||
| GELOGD("MultiBatchPass Enter"); | |||
| @@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { | |||
| return FAILED; | |||
| } | |||
| std::vector<std::vector<int64_t>> batch_shape; | |||
| vector<vector<int64_t>> combined_batch; | |||
| std::vector<std::vector<int64_t>> combined_batch; | |||
| if (!CheckSwitchN(batch_shape, combined_batch)) { | |||
| GELOGE(FAILED, "CheckSwitchN failed."); | |||
| return FAILED; | |||
| @@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() { | |||
| /// | |||
| Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { | |||
| const auto &func_desc = case_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(func_desc); | |||
| if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { | |||
| GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); | |||
| return SUCCESS; | |||
| @@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr | |||
| const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); | |||
| GE_CHECK_NOTNULL(subgraph); | |||
| const string batch_label = "Batch_" + std::to_string(i); | |||
| const std::string batch_label = "Batch_" + std::to_string(i); | |||
| for (const auto &node : subgraph->GetDirectNode()) { | |||
| (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); | |||
| } | |||
| @@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||
| continue; | |||
| } | |||
| InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
| const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
| if (in_data_anchor == nullptr) { | |||
| GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); | |||
| const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); | |||
| if (pred_input == nullptr) { | |||
| GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); | |||
| return FAILED; | |||
| @@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||
| /// @return Status | |||
| /// | |||
| Status MultiBatchPass::GetDynamicType() { | |||
| for (const auto &switchn : switch_n_nodes_) { | |||
| auto switchn_desc = switchn->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(switchn_desc); | |||
| for (const auto &switch_n : switch_n_nodes_) { | |||
| int32_t dynamic_type = static_cast<int32_t>(FIXED); | |||
| if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||
| GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); | |||
| if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||
| GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (dynamic_type == static_cast<int32_t>(FIXED)) { | |||
| @@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() { | |||
| return FAILED; | |||
| } | |||
| if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { | |||
| GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", | |||
| GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", | |||
| dynamic_type, dynamic_type_); | |||
| return FAILED; | |||
| } | |||
| @@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() { | |||
| Status MultiBatchPass::GetUserDesignateShape() { | |||
| data_name_order_.clear(); | |||
| bool first_check = true; | |||
| for (const auto &switchn : switch_n_nodes_) { | |||
| auto switchn_desc = switchn->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(switchn_desc); | |||
| vector<string> cur_switchn_data_name_order; | |||
| if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { | |||
| GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); | |||
| for (const auto &switch_n : switch_n_nodes_) { | |||
| std::vector<std::string> cur_data_name_order; | |||
| if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { | |||
| GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if (first_check) { | |||
| data_name_order_ = cur_switchn_data_name_order; | |||
| data_name_order_ = cur_data_name_order; | |||
| first_check = false; | |||
| } else { | |||
| if (data_name_order_ != cur_switchn_data_name_order) { | |||
| if (data_name_order_ != cur_data_name_order) { | |||
| GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", | |||
| switchn->GetName().c_str()); | |||
| switch_n->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| @@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() { | |||
| /// @param [out] combined_batch | |||
| /// @return bool | |||
| /// | |||
| bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) { | |||
| bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape, | |||
| std::vector<std::vector<int64_t>> &combined_batch) { | |||
| // Check if output_num of different SwitchN is same | |||
| uint32_t batch_num = 0; | |||
| for (const NodePtr &node : switch_n_nodes_) { | |||
| @@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||
| } | |||
| size_t tmp_combined_dim_num = combined_batch[i].size(); | |||
| if (combined_dim_num != tmp_combined_dim_num) { | |||
| GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); | |||
| GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", | |||
| combined_dim_num, i, tmp_combined_dim_num); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||
| /// @param [out] combined_batch | |||
| /// @return bool | |||
| /// | |||
| bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape, | |||
| vector<vector<int64_t>> &combined_batch) { | |||
| bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape, | |||
| std::vector<std::vector<int64_t>> &combined_batch) { | |||
| // Check if output_shape of different SwitchN is same | |||
| vector<vector<int64_t>> idx_batch_shape; | |||
| vector<vector<int64_t>> idx_combined_batch; | |||
| std::vector<std::vector<int64_t>> idx_batch_shape; | |||
| std::vector<std::vector<int64_t>> idx_combined_batch; | |||
| for (uint32_t i = 0; i < batch_num; i++) { | |||
| idx_batch_shape.clear(); | |||
| idx_combined_batch.clear(); | |||
| @@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b | |||
| GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); | |||
| return false; | |||
| } | |||
| vector<int64_t> output_dims; | |||
| std::vector<int64_t> output_dims; | |||
| if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { | |||
| GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); | |||
| return false; | |||
| @@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { | |||
| /// @return Status | |||
| /// | |||
| Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, | |||
| const vector<vector<int64_t>> &batch_shape, | |||
| const vector<vector<int64_t>> &combined_batch) { | |||
| const std::vector<std::vector<int64_t>> &batch_shape, | |||
| const std::vector<std::vector<int64_t>> &combined_batch) { | |||
| NodePtr pred_value_node = pred_value->GetOwnerNode(); | |||
| // Create SwitchCase node | |||
| const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; | |||
| @@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||
| return false; | |||
| } | |||
| size_t num = output_shape.size(); | |||
| size_t dim_num = output_shape[0].size(); | |||
| for (size_t i = 1; i < num; i++) { | |||
| size_t tmp_dim_num = output_shape[i].size(); | |||
| if (dim_num != tmp_dim_num) { | |||
| GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); | |||
| for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) { | |||
| if (output_shape[0] != *iter) { | |||
| return false; | |||
| } | |||
| } | |||
| if (dim_num == 0) { | |||
| return true; | |||
| } | |||
| for (size_t i = 0; i < dim_num; i++) { | |||
| int64_t dim_value = output_shape[0][i]; | |||
| for (size_t j = 1; j < num; j++) { | |||
| int64_t tmp_dim_value = output_shape[j][i]; | |||
| if (dim_value != tmp_dim_value) { | |||
| GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, | |||
| dim_value, j, tmp_dim_value); | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||
| /// | |||
| NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, | |||
| const OutDataAnchorPtr &pred_value, | |||
| const vector<vector<int64_t>> &batch_shape, | |||
| const vector<vector<int64_t>> &combined_batch) { | |||
| const std::vector<std::vector<int64_t>> &batch_shape, | |||
| const std::vector<std::vector<int64_t>> &combined_batch) { | |||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); | |||
| if (op_desc == nullptr) { | |||
| GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); | |||
| @@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const | |||
| GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); | |||
| return nullptr; | |||
| } | |||
| const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||
| const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||
| if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { | |||
| GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); | |||
| return nullptr; | |||
| @@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra | |||
| std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; | |||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||
| GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); | |||
| if ((type == SWITCH) || (type == REFSWITCH)) { | |||
| InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
| GE_CHECK_NOTNULL(in_cond_anchor); | |||
| OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { | |||
| GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| if ((type != SWITCH) && (type != REFSWITCH)) { | |||
| continue; | |||
| } | |||
| InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
| GE_CHECK_NOTNULL(in_cond_anchor); | |||
| OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) { | |||
| GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||
| auto iter = cond_switch_map.find(cond_node); | |||
| if (iter == cond_switch_map.end()) { | |||
| cond_switch_map[cond_node] = { node }; | |||
| } else { | |||
| iter->second.emplace_back(node); | |||
| } | |||
| switch_nodes_.emplace_back(node); | |||
| NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||
| auto iter = cond_switch_map.find(cond_node); | |||
| if (iter == cond_switch_map.end()) { | |||
| cond_switch_map[cond_node] = { node }; | |||
| } else { | |||
| iter->second.emplace_back(node); | |||
| } | |||
| switch_nodes_.emplace_back(node); | |||
| } | |||
| MarkCycleDependence(cond_switch_map); | |||
| @@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||
| if (idx == SWITCH_DATA_INPUT) { | |||
| peer_data_anchor = peer_out_anchor; | |||
| } else { | |||
| if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { | |||
| GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| peer_cond_anchor = peer_out_anchor; | |||
| } | |||
| } | |||
| @@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||
| /// | |||
| /// @brief Find Switch cond input | |||
| /// @param [in] pass_switch_flag | |||
| /// @param [out] peer_cond_anchor | |||
| /// @return Status | |||
| /// | |||
| Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { | |||
| Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) { | |||
| NodePtr tmp_node = nullptr; | |||
| string type; | |||
| bool need_pass_type = true; | |||
| while (need_pass_type) { | |||
| std::string type; | |||
| bool pass_flag = true; | |||
| while (pass_flag) { | |||
| if (tmp_node == nullptr) { | |||
| tmp_node = peer_cond_anchor->GetOwnerNode(); | |||
| } else { | |||
| @@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD | |||
| } | |||
| GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); | |||
| need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); | |||
| pass_flag = ((type == SWITCH) || (type == REFSWITCH)); | |||
| } | |||
| return SUCCESS; | |||
| @@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||
| } | |||
| } else { | |||
| int64_t switch_group_id = GetGroupId(stream_switch); | |||
| map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||
| std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||
| std::list<NodePtr> false_node_list; | |||
| std::list<NodePtr> true_node_list; | |||
| std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; | |||
| @@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||
| /// @return group_id | |||
| /// | |||
| int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
| string tailing_optimization_option; | |||
| std::string tailing_optimization_option; | |||
| bool is_tailing_optimization = false; | |||
| if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { | |||
| // "1" means it's True from frontend option | |||
| @@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
| return 0; | |||
| } | |||
| string hccl_group_id; | |||
| std::string hccl_group_id; | |||
| if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { | |||
| GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); | |||
| return 0; | |||
| @@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | |||
| OutDataAnchorPtr peer_cond_anchor = iter->first; | |||
| GE_CHECK_NOTNULL(peer_cond_anchor); | |||
| NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); | |||
| GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); | |||
| @@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con | |||
| NodePtr cast_node = graph->AddNode(cast_desc); | |||
| GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); | |||
| // Cast node has and only has one input | |||
| GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); | |||
| return cast_node; | |||
| @@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no | |||
| return INTERNAL_ERROR; | |||
| } | |||
| for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||
| for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||
| "Remove ctl edge failed."); | |||
| GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||
| GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
| GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||
| GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
| "Add ctl edge failed."); | |||
| }); | |||
| GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); | |||
| if (same_cond_switch.count(in_ctl_node) > 0) { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
| GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue); | |||
| if (same_cond_switch.count(in_ctrl_node) > 0) { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
| "Remove ctl edge failed."); | |||
| continue; | |||
| } | |||
| auto find_res1 = switch_node_map_.find(in_ctl_node); | |||
| auto find_res1 = switch_node_map_.find(in_ctrl_node); | |||
| GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { | |||
| GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| }); | |||
| auto find_res2 = find_res1->second.find(orig_switch_name); | |||
| @@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass { | |||
| /// | |||
| /// @brief Find Switch cond input | |||
| /// @param [in] pass_switch_flag | |||
| /// @param [out] peer_cond_anchor | |||
| /// @return Status | |||
| /// | |||
| Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); | |||
| Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); | |||
| /// | |||
| /// @brief Create StreamSwitch Node | |||