Browse Source

modify for security check of control pass

tags/v1.2.0
chenyemeng 3 years ago
parent
commit
9909a46e2a
8 changed files with 119 additions and 173 deletions
  1. +23
    -31
      ge/graph/passes/attach_stream_label_pass.cc
  2. +8
    -2
      ge/graph/passes/cond_remove_pass.cc
  3. +9
    -16
      ge/graph/passes/enter_pass.cc
  4. +7
    -20
      ge/graph/passes/for_pass.cc
  5. +2
    -7
      ge/graph/passes/merge_pass.cc
  6. +32
    -56
      ge/graph/passes/multi_batch_pass.cc
  7. +37
    -39
      ge/graph/passes/switch_to_stream_switch_pass.cc
  8. +1
    -2
      ge/graph/passes/switch_to_stream_switch_pass.h

+ 23
- 31
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {

FindNodes(graph);
for (const auto &node : need_label_nodes_) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) {
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str());
}
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str());
}
GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed.");

@@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() {
///
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) {
for (const NodePtr &node : graph->GetDirectNode()) {
const std::string &type = node->GetType();
if (type == STREAMSWITCH) {
const auto &op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
const std::string &type = op_desc->GetType();
if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) {
stream_switch_nodes_.emplace_back(node);
} else if (type == STREAMMERGE) {
if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
need_label_nodes_.emplace_back(node);
}
} else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
need_label_nodes_.emplace_back(node);
} else if ((type == ENTER) || (type == REFENTER)) {
enter_nodes_.emplace_back(node);
}
@@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) {
///
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
std::string stream_label;
if (AttachFlag(node, stream_label) != SUCCESS) {
GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str());
return FAILED;
}

std::unordered_set<NodePtr> branch_nodes;
std::unordered_set<NodePtr> visited;
std::stack<NodePtr> nodes;
nodes.push(node);

static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE};
while (!nodes.empty()) {
NodePtr cur_node = nodes.top();
@@ -95,10 +97,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
if (visited.count(cur_node) > 0) {
continue;
}
if (AttachFlag(cur_node, stream_label) != SUCCESS) {
GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str());
return FAILED;
}

const std::string &type = cur_node->GetType();
for (const auto &out_node : cur_node->GetOutAllNodes()) {
@@ -115,10 +113,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) {
visited.insert(cur_node);
}

if (node->GetType() == STREAMSWITCH) {
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed.");
}

for (const NodePtr &tmp_node : branch_nodes) {
GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str());
GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed.");
@@ -148,11 +142,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea
GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED,
"StreamSwitch get attr TRUE_BRANCH_STREAM failed.");
stream_label += (value ? "_t" : "_f");
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed.");
} else if (type == STREAMMERGE) {
stream_label = node->GetName();
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed.");
} else if ((type == EXIT) || (type == REFEXIT)) {
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed.");
}

return SUCCESS;
@@ -166,12 +159,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map;
for (const auto &enter_node : enter_nodes_) {
for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) {
if (out_ctrl_node->GetType() == STREAMACTIVE) {
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) {
enter_active_map[out_ctrl_node] = {enter_node};
} else {
enter_active_map[out_ctrl_node].emplace_back(enter_node);
}
if (out_ctrl_node->GetType() != STREAMACTIVE) {
continue;
}
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) {
enter_active_map[out_ctrl_node] = {enter_node};
} else {
enter_active_map[out_ctrl_node].emplace_back(enter_node);
}
}
}
@@ -226,9 +220,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
std::string stream_label;
GE_CHECK_NOTNULL(active_node);
(void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label);

if (stream_label.empty()) {
GELOGW("stream_label of enter_active & enter_nodes is empty.");
GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str());
return SUCCESS;
}

@@ -238,7 +231,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
}
}
GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed.");
return SUCCESS;
}



+ 8
- 2
ge/graph/passes/cond_remove_pass.cc View File

@@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) {
OutDataAnchorPtr cond_out_anchor = nullptr;
InDataAnchorPtr cond_in_anchor = nullptr;
Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor);
if (ret == NOT_CHANGED) {
return SUCCESS;
} else if (ret != SUCCESS) {
GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str());
return FAILED;
}
int32_t cond_index = 0;
GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str());
bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index);
@@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph,
std::string type = node->GetType();
if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) {
if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) {
GELOGE(FAILED, "Get cond_info for if node failed.");
GELOGE(FAILED, "Get cond_info for if/case node failed.");
return FAILED;
}
} else {
GELOGD("no need cond_pass for node %s.", node->GetName().c_str());
GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str());
return NOT_CHANGED;
}



+ 9
- 16
ge/graph/passes/enter_pass.cc View File

@@ -16,6 +16,7 @@

#include "graph/passes/enter_pass.h"

#include "graph/debug/ge_attr_define.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/utils/graph_utils.h"
@@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) {
}

Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
auto out_nodes_of_in_node = in_node->GetOutAllNodes();
if (out_nodes_of_in_node.size() != kOutNodesNum) {
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) {
return SUCCESS;
}

if (!node->GetOutControlNodes().empty()) {
bool is_constant_flag = true;
(void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag);
if (!is_constant_flag) {
return SUCCESS;
}

for (const auto &out_node : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(out_node);
if (out_node->GetType() == MERGE) {
return SUCCESS;
}
}

GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)));
auto out_data_anchor = node->GetOutDataAnchor(0);
const auto &out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(out_data_anchor);
for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor));
}

auto graph = node->GetOwnerComputeGraph();
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node));
AddNodeDeleted(node);
AddRePassNodesWithInOut(in_node);

return SUCCESS;


+ 7
- 20
ge/graph/passes/for_pass.cc View File

@@ -137,7 +137,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n
for_info.ctrl_inputs = std::move(ctrl_inputs);
for_info.ctrl_outputs = std::move(ctrl_outputs);

GELOGI("Build for_info for node %s succ.", node->GetName().c_str());
GELOGI("Build for_info for node %s success.", node->GetName().c_str());
return SUCCESS;
}

@@ -159,13 +159,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index
return nullptr;
}

OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index);
return nullptr;
}

return peer_out_anchor;
return in_data_anchor->GetPeerOutAnchor();
}

///
@@ -186,20 +180,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc
uint32_t input_data_num = node->GetAllInDataAnchorsSize();
for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
if (in_data_anchor == nullptr) {
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
return FAILED;
}
GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr,
GELOGW("Get null input by index %d from node %s ",
in_data_anchor->GetIdx(), node->GetName().c_str());
continue);
GE_CHECK_NOTNULL(in_data_anchor);
data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
}

for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
peer_in_data_anchors.emplace_back(peer_in_data_anchor);
}
data_outputs.emplace_back(peer_in_data_anchors);
@@ -207,13 +194,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc

InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor);
for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
}

OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor);
for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
}



+ 2
- 7
ge/graph/passes/merge_pass.cc View File

@@ -21,16 +21,12 @@
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/passes/pass_utils.h"

using domi::PARAM_INVALID;
using domi::SUCCESS;

namespace ge {
const int kValueIndexOutputIndex = 1;

@@ -47,13 +43,12 @@ Status MergePass::Run(NodePtr &node) {
return SUCCESS;
}

auto out_data_anchors = node->GetAllOutDataAnchors();
if (out_data_anchors.empty()) {
if (node->GetAllOutDataAnchors().empty()) {
GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str());
return PARAM_INVALID;
}

auto in_data_nodes = node->GetInDataNodes();
const auto &in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) {
case 0: {
/// Case A: input_count = 0, the output of merge node is inactive as well


+ 32
- 56
ge/graph/passes/multi_batch_pass.cc View File

@@ -22,9 +22,6 @@
#include "graph/common/omg_util.h"
#include "graph/utils/type_utils.h"

using std::string;
using std::vector;

namespace ge {
Status MultiBatchPass::Run(ComputeGraphPtr graph) {
GELOGD("MultiBatchPass Enter");
@@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) {
return FAILED;
}
std::vector<std::vector<int64_t>> batch_shape;
vector<vector<int64_t>> combined_batch;
std::vector<std::vector<int64_t>> combined_batch;
if (!CheckSwitchN(batch_shape, combined_batch)) {
GELOGE(FAILED, "CheckSwitchN failed.");
return FAILED;
@@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() {
///
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
const auto &func_desc = case_node->GetOpDesc();
GE_CHECK_NOTNULL(func_desc);
if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
return SUCCESS;
@@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr
const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
GE_CHECK_NOTNULL(subgraph);

const string batch_label = "Batch_" + std::to_string(i);
const std::string batch_label = "Batch_" + std::to_string(i);
for (const auto &node : subgraph->GetDirectNode()) {
(void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
}
@@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor
continue;
}

InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
if (in_data_anchor == nullptr) {
GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
return FAILED;
}
OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor();
const auto &pred_input = in_data_anchor->GetPeerOutAnchor();
if (pred_input == nullptr) {
GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
return FAILED;
@@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor
/// @return Status
///
Status MultiBatchPass::GetDynamicType() {
for (const auto &switchn : switch_n_nodes_) {
auto switchn_desc = switchn->GetOpDesc();
GE_CHECK_NOTNULL(switchn_desc);
for (const auto &switch_n : switch_n_nodes_) {
int32_t dynamic_type = static_cast<int32_t>(FIXED);
if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) {
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str());
if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) {
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str());
return FAILED;
}
if (dynamic_type == static_cast<int32_t>(FIXED)) {
@@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() {
return FAILED;
}
if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.",
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.",
dynamic_type, dynamic_type_);
return FAILED;
}
@@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() {
Status MultiBatchPass::GetUserDesignateShape() {
data_name_order_.clear();
bool first_check = true;
for (const auto &switchn : switch_n_nodes_) {
auto switchn_desc = switchn->GetOpDesc();
GE_CHECK_NOTNULL(switchn_desc);
vector<string> cur_switchn_data_name_order;
if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) {
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str());
for (const auto &switch_n : switch_n_nodes_) {
std::vector<std::string> cur_data_name_order;
if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) {
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str());
return FAILED;
}
if (first_check) {
data_name_order_ = cur_switchn_data_name_order;
data_name_order_ = cur_data_name_order;
first_check = false;
} else {
if (data_name_order_ != cur_switchn_data_name_order) {
if (data_name_order_ != cur_data_name_order) {
GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.",
switchn->GetName().c_str());
switch_n->GetName().c_str());
return FAILED;
}
}
@@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() {
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) {
bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape,
std::vector<std::vector<int64_t>> &combined_batch) {
// Check if output_num of different SwitchN is same
uint32_t batch_num = 0;
for (const NodePtr &node : switch_n_nodes_) {
@@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v
}
size_t tmp_combined_dim_num = combined_batch[i].size();
if (combined_dim_num != tmp_combined_dim_num) {
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.",
combined_dim_num, i, tmp_combined_dim_num);
return false;
}
}
@@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape,
vector<vector<int64_t>> &combined_batch) {
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape,
std::vector<std::vector<int64_t>> &combined_batch) {
// Check if output_shape of different SwitchN is same
vector<vector<int64_t>> idx_batch_shape;
vector<vector<int64_t>> idx_combined_batch;
std::vector<std::vector<int64_t>> idx_batch_shape;
std::vector<std::vector<int64_t>> idx_combined_batch;
for (uint32_t i = 0; i < batch_num; i++) {
idx_batch_shape.clear();
idx_combined_batch.clear();
@@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b
GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
return false;
}
vector<int64_t> output_dims;
std::vector<int64_t> output_dims;
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
return false;
@@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
/// @return Status
///
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
const vector<vector<int64_t>> &batch_shape,
const vector<vector<int64_t>> &combined_batch) {
const std::vector<std::vector<int64_t>> &batch_shape,
const std::vector<std::vector<int64_t>> &combined_batch) {
NodePtr pred_value_node = pred_value->GetOwnerNode();
// Create SwitchCase node
const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
@@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
return false;
}

size_t num = output_shape.size();
size_t dim_num = output_shape[0].size();
for (size_t i = 1; i < num; i++) {
size_t tmp_dim_num = output_shape[i].size();
if (dim_num != tmp_dim_num) {
GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) {
if (output_shape[0] != *iter) {
return false;
}
}

if (dim_num == 0) {
return true;
}

for (size_t i = 0; i < dim_num; i++) {
int64_t dim_value = output_shape[0][i];
for (size_t j = 1; j < num; j++) {
int64_t tmp_dim_value = output_shape[j][i];
if (dim_value != tmp_dim_value) {
GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
dim_value, j, tmp_dim_value);
return false;
}
}
}
return true;
}

@@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
///
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
const OutDataAnchorPtr &pred_value,
const vector<vector<int64_t>> &batch_shape,
const vector<vector<int64_t>> &combined_batch) {
const std::vector<std::vector<int64_t>> &batch_shape,
const std::vector<std::vector<int64_t>> &combined_batch) {
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
if (op_desc == nullptr) {
GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
@@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const
GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
return nullptr;
}
const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
return nullptr;


+ 37
- 39
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra
std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map;
for (const NodePtr &node : graph->GetDirectNode()) {
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.");
if ((type == SWITCH) || (type == REFSWITCH)) {
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
GE_CHECK_NOTNULL(in_cond_anchor);
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) {
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str());
return FAILED;
}
if ((type != SWITCH) && (type != REFSWITCH)) {
continue;
}
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
GE_CHECK_NOTNULL(in_cond_anchor);
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) {
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str());
return FAILED;
}

NodePtr cond_node = peer_out_anchor->GetOwnerNode();
auto iter = cond_switch_map.find(cond_node);
if (iter == cond_switch_map.end()) {
cond_switch_map[cond_node] = { node };
} else {
iter->second.emplace_back(node);
}
switch_nodes_.emplace_back(node);
NodePtr cond_node = peer_out_anchor->GetOwnerNode();
auto iter = cond_switch_map.find(cond_node);
if (iter == cond_switch_map.end()) {
cond_switch_map[cond_node] = { node };
} else {
iter->second.emplace_back(node);
}
switch_nodes_.emplace_back(node);
}

MarkCycleDependence(cond_switch_map);
@@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou
if (idx == SWITCH_DATA_INPUT) {
peer_data_anchor = peer_out_anchor;
} else {
if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) {
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str());
return FAILED;
}
peer_cond_anchor = peer_out_anchor;
}
}
@@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou

///
/// @brief Find Switch cond input
/// @param [in] pass_switch_flag
/// @param [out] peer_cond_anchor
/// @return Status
///
Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) {
Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) {
NodePtr tmp_node = nullptr;
string type;
bool need_pass_type = true;
while (need_pass_type) {
std::string type;
bool pass_flag = true;
while (pass_flag) {
if (tmp_node == nullptr) {
tmp_node = peer_cond_anchor->GetOwnerNode();
} else {
@@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD
}

GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed.");
need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH)));
pass_flag = ((type == SWITCH) || (type == REFSWITCH));
}

return SUCCESS;
@@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_
}
} else {
int64_t switch_group_id = GetGroupId(stream_switch);
map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
std::list<NodePtr> false_node_list;
std::list<NodePtr> true_node_list;
std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
@@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_
/// @return group_id
///
int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
string tailing_optimization_option;
std::string tailing_optimization_option;
bool is_tailing_optimization = false;
if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) {
// "1" means it's True from frontend option
@@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
return 0;
}

string hccl_group_id;
std::string hccl_group_id;
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
GELOGI("Node %s can not find hccl group id.", node->GetName().c_str());
return 0;
@@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());

OutDataAnchorPtr peer_cond_anchor = iter->first;
GE_CHECK_NOTNULL(peer_cond_anchor);
NodePtr cond_node = peer_cond_anchor->GetOwnerNode();
GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str());

@@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con

NodePtr cast_node = graph->AddNode(cast_desc);
GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed.");
// Cast node has and only has one input
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed.");

return cast_node;
@@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no
return INTERNAL_ERROR;
}

for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
"Remove ctl edge failed.");
GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
"Add ctl edge failed.");
});

GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue);
if (same_cond_switch.count(in_ctl_node) > 0) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue);
if (same_cond_switch.count(in_ctrl_node) > 0) {
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
"Remove ctl edge failed.");
continue;
}

auto find_res1 = switch_node_map_.find(in_ctl_node);
auto find_res1 = switch_node_map_.find(in_ctrl_node);
GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), {
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str());
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str());
return INTERNAL_ERROR;
});
auto find_res2 = find_res1->second.find(orig_switch_name);


+ 1
- 2
ge/graph/passes/switch_to_stream_switch_pass.h View File

@@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass {

///
/// @brief Find Switch cond input
/// @param [in] pass_switch_flag
/// @param [out] peer_cond_anchor
/// @return Status
///
Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor);
Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor);

///
/// @brief Create StreamSwitch Node


Loading…
Cancel
Save