From 9909a46e2ad1d91e571a30e541dab29f462e0f9d Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Wed, 2 Dec 2020 14:49:58 +0800 Subject: [PATCH] modify for security check of control pass --- ge/graph/passes/attach_stream_label_pass.cc | 54 +++++------- ge/graph/passes/cond_remove_pass.cc | 10 ++- ge/graph/passes/enter_pass.cc | 25 ++---- ge/graph/passes/for_pass.cc | 27 ++---- ge/graph/passes/merge_pass.cc | 9 +- ge/graph/passes/multi_batch_pass.cc | 88 +++++++------------ .../passes/switch_to_stream_switch_pass.cc | 76 ++++++++-------- .../passes/switch_to_stream_switch_pass.h | 3 +- 8 files changed, 119 insertions(+), 173 deletions(-) diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index b04643a4..c0e0f669 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { FindNodes(graph); for (const auto &node : need_label_nodes_) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { - GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); - } + GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); } GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); @@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() { /// void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { for (const NodePtr &node : graph->GetDirectNode()) { - const std::string &type = node->GetType(); - if (type == STREAMSWITCH) { + const auto &op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + const std::string &type = op_desc->GetType(); + if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { stream_switch_nodes_.emplace_back(node); - } else if (type == STREAMMERGE) { - if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - need_label_nodes_.emplace_back(node); - } + } else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + need_label_nodes_.emplace_back(node); } else if ((type == ENTER) || (type == REFENTER)) { enter_nodes_.emplace_back(node); } @@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { /// Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { std::string stream_label; + if (AttachFlag(node, stream_label) != SUCCESS) { + GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); + return FAILED; + } + std::unordered_set branch_nodes; std::unordered_set visited; std::stack nodes; nodes.push(node); - static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; while (!nodes.empty()) { NodePtr cur_node = nodes.top(); @@ -95,10 +97,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { if (visited.count(cur_node) > 0) { continue; } - if (AttachFlag(cur_node, stream_label) != SUCCESS) { - GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); - return FAILED; - } const std::string &type = cur_node->GetType(); for (const auto &out_node : cur_node->GetOutAllNodes()) { @@ -115,10 +113,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { visited.insert(cur_node); } - if (node->GetType() == STREAMSWITCH) { - GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); - } - for (const NodePtr &tmp_node : branch_nodes) { GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); @@ -148,11 +142,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); stream_label += (value ? "_t" : "_f"); + GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); } else if (type == STREAMMERGE) { stream_label = node->GetName(); GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); - } else if ((type == EXIT) || (type == REFEXIT)) { - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); } return SUCCESS; @@ -166,12 +159,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() { std::unordered_map> enter_active_map; for (const auto &enter_node : enter_nodes_) { for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { - if (out_ctrl_node->GetType() == STREAMACTIVE) { - if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { - enter_active_map[out_ctrl_node] = {enter_node}; - } else { - enter_active_map[out_ctrl_node].emplace_back(enter_node); - } + if (out_ctrl_node->GetType() != STREAMACTIVE) { + continue; + } + if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { + enter_active_map[out_ctrl_node] = {enter_node}; + } else { + enter_active_map[out_ctrl_node].emplace_back(enter_node); } } } @@ -226,9 +220,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no std::string stream_label; GE_CHECK_NOTNULL(active_node); (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); - if (stream_label.empty()) { - GELOGW("stream_label of enter_active & enter_nodes is empty."); + GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str()); return SUCCESS; } @@ -238,7 +231,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); } } - GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); return SUCCESS; } diff --git a/ge/graph/passes/cond_remove_pass.cc b/ge/graph/passes/cond_remove_pass.cc index e8d1493f..bf2e1170 100644 --- a/ge/graph/passes/cond_remove_pass.cc +++ b/ge/graph/passes/cond_remove_pass.cc @@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) { OutDataAnchorPtr cond_out_anchor = nullptr; InDataAnchorPtr cond_in_anchor = nullptr; Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); + if (ret == NOT_CHANGED) { + return SUCCESS; + } else if (ret != SUCCESS) { + GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); + return FAILED; + } int32_t cond_index = 0; GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); @@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, std::string type = node->GetType(); if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { - GELOGE(FAILED, "Get cond_info for if node failed."); + GELOGE(FAILED, "Get cond_info for if/case node failed."); return FAILED; } } else { - GELOGD("no need cond_pass for node %s.", node->GetName().c_str()); + GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str()); return NOT_CHANGED; } diff --git a/ge/graph/passes/enter_pass.cc b/ge/graph/passes/enter_pass.cc index 206d271c..afeca78f 100644 --- a/ge/graph/passes/enter_pass.cc +++ b/ge/graph/passes/enter_pass.cc @@ -16,6 +16,7 @@ #include "graph/passes/enter_pass.h" +#include "graph/debug/ge_attr_define.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "graph/utils/graph_utils.h" @@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) { } Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { - auto out_nodes_of_in_node = in_node->GetOutAllNodes(); - if (out_nodes_of_in_node.size() != kOutNodesNum) { + if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { return SUCCESS; } - - if (!node->GetOutControlNodes().empty()) { + bool is_constant_flag = true; + (void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag); + if (!is_constant_flag) { return SUCCESS; } - for (const auto &out_node : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(out_node); - if (out_node->GetType() == MERGE) { - return SUCCESS; - } - } - GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); - auto out_data_anchor = node->GetOutDataAnchor(0); + const auto &out_data_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_data_anchor); - for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); } - - auto graph = node->GetOwnerComputeGraph(); - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node)) + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); + AddNodeDeleted(node); AddRePassNodesWithInOut(in_node); return SUCCESS; diff --git a/ge/graph/passes/for_pass.cc b/ge/graph/passes/for_pass.cc index f3caea35..f5280a36 100644 --- a/ge/graph/passes/for_pass.cc +++ b/ge/graph/passes/for_pass.cc @@ -137,7 +137,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n for_info.ctrl_inputs = std::move(ctrl_inputs); for_info.ctrl_outputs = std::move(ctrl_outputs); - GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); + GELOGI("Build for_info for node %s success.", node->GetName().c_str()); return SUCCESS; } @@ -159,13 +159,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index return nullptr; } - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); - return nullptr; - } - - return peer_out_anchor; + return in_data_anchor->GetPeerOutAnchor(); } /// @@ -186,20 +180,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vectorGetAllInDataAnchorsSize(); for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); - if (in_data_anchor == nullptr) { - GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); - return FAILED; - } - GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr, - GELOGW("Get null input by index %d from node %s ", - in_data_anchor->GetIdx(), node->GetName().c_str()); - continue); + GE_CHECK_NOTNULL(in_data_anchor); data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); } - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { std::vector peer_in_data_anchors; - for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { peer_in_data_anchors.emplace_back(peer_in_data_anchor); } data_outputs.emplace_back(peer_in_data_anchors); @@ -207,13 +194,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vectorGetInControlAnchor(); GE_CHECK_NOTNULL(in_ctrl_anchor); - for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { + for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { ctrl_inputs.emplace_back(peer_out_ctrl_anchor); } OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); GE_CHECK_NOTNULL(out_ctrl_anchor); - for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { ctrl_outputs.emplace_back(peer_in_ctrl_anchor); } diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index d2340037..80394e7a 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -21,16 +21,12 @@ #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" -using domi::PARAM_INVALID; -using domi::SUCCESS; - namespace ge { const int kValueIndexOutputIndex = 1; @@ -47,13 +43,12 @@ Status MergePass::Run(NodePtr &node) { return SUCCESS; } - auto out_data_anchors = node->GetAllOutDataAnchors(); - if (out_data_anchors.empty()) { + if (node->GetAllOutDataAnchors().empty()) { GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); return PARAM_INVALID; } - auto in_data_nodes = node->GetInDataNodes(); + const auto &in_data_nodes = node->GetInDataNodes(); switch (in_data_nodes.size()) { case 0: { /// Case A: input_count = 0, the output of merge node is inactive as well diff --git a/ge/graph/passes/multi_batch_pass.cc b/ge/graph/passes/multi_batch_pass.cc index c7034612..74f7e30e 100644 --- a/ge/graph/passes/multi_batch_pass.cc +++ b/ge/graph/passes/multi_batch_pass.cc @@ -22,9 +22,6 @@ #include "graph/common/omg_util.h" #include "graph/utils/type_utils.h" -using std::string; -using std::vector; - namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); @@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return FAILED; } std::vector> batch_shape; - vector> combined_batch; + std::vector> combined_batch; if (!CheckSwitchN(batch_shape, combined_batch)) { GELOGE(FAILED, "CheckSwitchN failed."); return FAILED; @@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() { /// Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { const auto &func_desc = case_node->GetOpDesc(); + GE_CHECK_NOTNULL(func_desc); if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); return SUCCESS; @@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); GE_CHECK_NOTNULL(subgraph); - const string batch_label = "Batch_" + std::to_string(i); + const std::string batch_label = "Batch_" + std::to_string(i); for (const auto &node : subgraph->GetDirectNode()) { (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); } @@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor continue; } - InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); if (in_data_anchor == nullptr) { GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); return FAILED; } - OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); + const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); if (pred_input == nullptr) { GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); return FAILED; @@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor /// @return Status /// Status MultiBatchPass::GetDynamicType() { - for (const auto &switchn : switch_n_nodes_) { - auto switchn_desc = switchn->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_desc); + for (const auto &switch_n : switch_n_nodes_) { int32_t dynamic_type = static_cast(FIXED); - if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { - GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); + if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { + GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); return FAILED; } if (dynamic_type == static_cast(FIXED)) { @@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() { return FAILED; } if (dynamic_type_ != static_cast(FIXED) && dynamic_type_ != dynamic_type) { - GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", + GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", dynamic_type, dynamic_type_); return FAILED; } @@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() { Status MultiBatchPass::GetUserDesignateShape() { data_name_order_.clear(); bool first_check = true; - for (const auto &switchn : switch_n_nodes_) { - auto switchn_desc = switchn->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_desc); - vector cur_switchn_data_name_order; - if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { - GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); + for (const auto &switch_n : switch_n_nodes_) { + std::vector cur_data_name_order; + if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { + GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); return FAILED; } if (first_check) { - data_name_order_ = cur_switchn_data_name_order; + data_name_order_ = cur_data_name_order; first_check = false; } else { - if (data_name_order_ != cur_switchn_data_name_order) { + if (data_name_order_ != cur_data_name_order) { GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", - switchn->GetName().c_str()); + switch_n->GetName().c_str()); return FAILED; } } @@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() { /// @param [out] combined_batch /// @return bool /// -bool MultiBatchPass::CheckSwitchN(vector> &batch_shape, vector> &combined_batch) { +bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape, + std::vector> &combined_batch) { // Check if output_num of different SwitchN is same uint32_t batch_num = 0; for (const NodePtr &node : switch_n_nodes_) { @@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector> &batch_shape, vector> &batch_shape, vector> &batch_shape, - vector> &combined_batch) { +bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector> &batch_shape, + std::vector> &combined_batch) { // Check if output_shape of different SwitchN is same - vector> idx_batch_shape; - vector> idx_combined_batch; + std::vector> idx_batch_shape; + std::vector> idx_combined_batch; for (uint32_t i = 0; i < batch_num; i++) { idx_batch_shape.clear(); idx_combined_batch.clear(); @@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector> &b GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); return false; } - vector output_dims; + std::vector output_dims; if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); return false; @@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { /// @return Status /// Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, - const vector> &batch_shape, - const vector> &combined_batch) { + const std::vector> &batch_shape, + const std::vector> &combined_batch) { NodePtr pred_value_node = pred_value->GetOwnerNode(); // Create SwitchCase node const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; @@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s return false; } - size_t num = output_shape.size(); - size_t dim_num = output_shape[0].size(); - for (size_t i = 1; i < num; i++) { - size_t tmp_dim_num = output_shape[i].size(); - if (dim_num != tmp_dim_num) { - GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); + for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) { + if (output_shape[0] != *iter) { return false; } } - - if (dim_num == 0) { - return true; - } - - for (size_t i = 0; i < dim_num; i++) { - int64_t dim_value = output_shape[0][i]; - for (size_t j = 1; j < num; j++) { - int64_t tmp_dim_value = output_shape[j][i]; - if (dim_value != tmp_dim_value) { - GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, - dim_value, j, tmp_dim_value); - return false; - } - } - } return true; } @@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s /// NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, - const vector> &batch_shape, - const vector> &combined_batch) { + const std::vector> &batch_shape, + const std::vector> &combined_batch) { OpDescPtr op_desc = MakeShared(name, STREAMSWITCHN); if (op_desc == nullptr) { GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); @@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } - const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); + const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); return nullptr; diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 529480a6..f75a104f 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra std::unordered_map> cond_switch_map; for (const NodePtr &node : graph->GetDirectNode()) { GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); - if ((type == SWITCH) || (type == REFSWITCH)) { - InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); - GE_CHECK_NOTNULL(in_cond_anchor); - OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); - return FAILED; - } + if ((type != SWITCH) && (type != REFSWITCH)) { + continue; + } + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + GE_CHECK_NOTNULL(in_cond_anchor); + OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) { + GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); + return FAILED; + } - NodePtr cond_node = peer_out_anchor->GetOwnerNode(); - auto iter = cond_switch_map.find(cond_node); - if (iter == cond_switch_map.end()) { - cond_switch_map[cond_node] = { node }; - } else { - iter->second.emplace_back(node); - } - switch_nodes_.emplace_back(node); + NodePtr cond_node = peer_out_anchor->GetOwnerNode(); + auto iter = cond_switch_map.find(cond_node); + if (iter == cond_switch_map.end()) { + cond_switch_map[cond_node] = { node }; + } else { + iter->second.emplace_back(node); } + switch_nodes_.emplace_back(node); } MarkCycleDependence(cond_switch_map); @@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou if (idx == SWITCH_DATA_INPUT) { peer_data_anchor = peer_out_anchor; } else { - if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); - return FAILED; - } peer_cond_anchor = peer_out_anchor; } } @@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou /// /// @brief Find Switch cond input -/// @param [in] pass_switch_flag /// @param [out] peer_cond_anchor /// @return Status /// -Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { +Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) { NodePtr tmp_node = nullptr; - string type; - bool need_pass_type = true; - while (need_pass_type) { + std::string type; + bool pass_flag = true; + while (pass_flag) { if (tmp_node == nullptr) { tmp_node = peer_cond_anchor->GetOwnerNode(); } else { @@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD } GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); - need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); + pass_flag = ((type == SWITCH) || (type == REFSWITCH)); } return SUCCESS; @@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ } } else { int64_t switch_group_id = GetGroupId(stream_switch); - map>> switch_group_map; + std::map>> switch_group_map; std::list false_node_list; std::list true_node_list; std::list &node_list = true_branch_flag ? true_node_list : false_node_list; @@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ /// @return group_id /// int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { - string tailing_optimization_option; + std::string tailing_optimization_option; bool is_tailing_optimization = false; if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { // "1" means it's True from frontend option @@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { return 0; } - string hccl_group_id; + std::string hccl_group_id; if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); return 0; @@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); OutDataAnchorPtr peer_cond_anchor = iter->first; + GE_CHECK_NOTNULL(peer_cond_anchor); NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); @@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con NodePtr cast_node = graph->AddNode(cast_desc); GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); + // Cast node has and only has one input GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); return cast_node; @@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no return INTERNAL_ERROR; } - for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), + for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), "Remove ctl edge failed."); - GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), "Add ctl edge failed."); }); - GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); - if (same_cond_switch.count(in_ctl_node) > 0) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue); + if (same_cond_switch.count(in_ctrl_node) > 0) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), "Remove ctl edge failed."); continue; } - auto find_res1 = switch_node_map_.find(in_ctl_node); + auto find_res1 = switch_node_map_.find(in_ctrl_node); GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { - GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str()); return INTERNAL_ERROR; }); auto find_res2 = find_res1->second.find(orig_switch_name); diff --git a/ge/graph/passes/switch_to_stream_switch_pass.h b/ge/graph/passes/switch_to_stream_switch_pass.h index 48725230..7070b647 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.h +++ b/ge/graph/passes/switch_to_stream_switch_pass.h @@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass { /// /// @brief Find Switch cond input - /// @param [in] pass_switch_flag /// @param [out] peer_cond_anchor /// @return Status /// - Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); + Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); /// /// @brief Create StreamSwitch Node