@@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||
FindNodes(graph); | |||
for (const auto &node : need_label_nodes_) { | |||
OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(op_desc); | |||
if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||
} | |||
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||
} | |||
GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); | |||
@@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() { | |||
/// | |||
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | |||
for (const NodePtr &node : graph->GetDirectNode()) { | |||
const std::string &type = node->GetType(); | |||
if (type == STREAMSWITCH) { | |||
const auto &op_desc = node->GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
continue; | |||
} | |||
const std::string &type = op_desc->GetType(); | |||
if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { | |||
stream_switch_nodes_.emplace_back(node); | |||
} else if (type == STREAMMERGE) { | |||
if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { | |||
need_label_nodes_.emplace_back(node); | |||
} | |||
} else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { | |||
need_label_nodes_.emplace_back(node); | |||
} else if ((type == ENTER) || (type == REFENTER)) { | |||
enter_nodes_.emplace_back(node); | |||
} | |||
@@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | |||
/// | |||
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||
std::string stream_label; | |||
if (AttachFlag(node, stream_label) != SUCCESS) { | |||
GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
std::unordered_set<NodePtr> branch_nodes; | |||
std::unordered_set<NodePtr> visited; | |||
std::stack<NodePtr> nodes; | |||
nodes.push(node); | |||
static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; | |||
while (!nodes.empty()) { | |||
NodePtr cur_node = nodes.top(); | |||
@@ -95,10 +97,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||
if (visited.count(cur_node) > 0) { | |||
continue; | |||
} | |||
if (AttachFlag(cur_node, stream_label) != SUCCESS) { | |||
GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
const std::string &type = cur_node->GetType(); | |||
for (const auto &out_node : cur_node->GetOutAllNodes()) { | |||
@@ -115,10 +113,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||
visited.insert(cur_node); | |||
} | |||
if (node->GetType() == STREAMSWITCH) { | |||
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||
} | |||
for (const NodePtr &tmp_node : branch_nodes) { | |||
GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); | |||
GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); | |||
@@ -148,11 +142,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea | |||
GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, | |||
"StreamSwitch get attr TRUE_BRANCH_STREAM failed."); | |||
stream_label += (value ? "_t" : "_f"); | |||
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||
} else if (type == STREAMMERGE) { | |||
stream_label = node->GetName(); | |||
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||
} else if ((type == EXIT) || (type == REFEXIT)) { | |||
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||
} | |||
return SUCCESS; | |||
@@ -166,12 +159,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||
std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; | |||
for (const auto &enter_node : enter_nodes_) { | |||
for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { | |||
if (out_ctrl_node->GetType() == STREAMACTIVE) { | |||
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||
enter_active_map[out_ctrl_node] = {enter_node}; | |||
} else { | |||
enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||
} | |||
if (out_ctrl_node->GetType() != STREAMACTIVE) { | |||
continue; | |||
} | |||
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||
enter_active_map[out_ctrl_node] = {enter_node}; | |||
} else { | |||
enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||
} | |||
} | |||
} | |||
@@ -226,9 +220,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||
std::string stream_label; | |||
GE_CHECK_NOTNULL(active_node); | |||
(void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); | |||
if (stream_label.empty()) { | |||
GELOGW("stream_label of enter_active & enter_nodes is empty."); | |||
GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -238,7 +231,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | |||
} | |||
} | |||
GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); | |||
return SUCCESS; | |||
} | |||
@@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) { | |||
OutDataAnchorPtr cond_out_anchor = nullptr; | |||
InDataAnchorPtr cond_in_anchor = nullptr; | |||
Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | |||
if (ret == NOT_CHANGED) { | |||
return SUCCESS; | |||
} else if (ret != SUCCESS) { | |||
GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
int32_t cond_index = 0; | |||
GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); | |||
bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); | |||
@@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, | |||
std::string type = node->GetType(); | |||
if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { | |||
if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | |||
GELOGE(FAILED, "Get cond_info for if node failed."); | |||
GELOGE(FAILED, "Get cond_info for if/case node failed."); | |||
return FAILED; | |||
} | |||
} else { | |||
GELOGD("no need cond_pass for node %s.", node->GetName().c_str()); | |||
GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str()); | |||
return NOT_CHANGED; | |||
} | |||
@@ -16,6 +16,7 @@ | |||
#include "graph/passes/enter_pass.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "framework/common/debug/log.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) { | |||
} | |||
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | |||
auto out_nodes_of_in_node = in_node->GetOutAllNodes(); | |||
if (out_nodes_of_in_node.size() != kOutNodesNum) { | |||
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | |||
return SUCCESS; | |||
} | |||
if (!node->GetOutControlNodes().empty()) { | |||
bool is_constant_flag = true; | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag); | |||
if (!is_constant_flag) { | |||
return SUCCESS; | |||
} | |||
for (const auto &out_node : node->GetOutDataNodes()) { | |||
GE_CHECK_NOTNULL(out_node); | |||
if (out_node->GetType() == MERGE) { | |||
return SUCCESS; | |||
} | |||
} | |||
GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | |||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | |||
auto out_data_anchor = node->GetOutDataAnchor(0); | |||
const auto &out_data_anchor = node->GetOutDataAnchor(0); | |||
GE_CHECK_NOTNULL(out_data_anchor); | |||
for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | |||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | |||
} | |||
auto graph = node->GetOwnerComputeGraph(); | |||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node)) | |||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||
AddNodeDeleted(node); | |||
AddRePassNodesWithInOut(in_node); | |||
return SUCCESS; | |||
@@ -137,7 +137,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n | |||
for_info.ctrl_inputs = std::move(ctrl_inputs); | |||
for_info.ctrl_outputs = std::move(ctrl_outputs); | |||
GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); | |||
GELOGI("Build for_info for node %s success.", node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -159,13 +159,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index | |||
return nullptr; | |||
} | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if (peer_out_anchor == nullptr) { | |||
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); | |||
return nullptr; | |||
} | |||
return peer_out_anchor; | |||
return in_data_anchor->GetPeerOutAnchor(); | |||
} | |||
/// | |||
@@ -186,20 +180,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||
uint32_t input_data_num = node->GetAllInDataAnchorsSize(); | |||
for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { | |||
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); | |||
if (in_data_anchor == nullptr) { | |||
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); | |||
return FAILED; | |||
} | |||
GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr, | |||
GELOGW("Get null input by index %d from node %s ", | |||
in_data_anchor->GetIdx(), node->GetName().c_str()); | |||
continue); | |||
GE_CHECK_NOTNULL(in_data_anchor); | |||
data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); | |||
} | |||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; | |||
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
peer_in_data_anchors.emplace_back(peer_in_data_anchor); | |||
} | |||
data_outputs.emplace_back(peer_in_data_anchors); | |||
@@ -207,13 +194,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||
InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||
for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||
ctrl_inputs.emplace_back(peer_out_ctrl_anchor); | |||
} | |||
OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); | |||
GE_CHECK_NOTNULL(out_ctrl_anchor); | |||
for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
ctrl_outputs.emplace_back(peer_in_ctrl_anchor); | |||
} | |||
@@ -21,16 +21,12 @@ | |||
#include <vector> | |||
#include "framework/common/debug/ge_log.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/ge/ge_util.h" | |||
#include "graph/common/omg_util.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/passes/pass_utils.h" | |||
using domi::PARAM_INVALID; | |||
using domi::SUCCESS; | |||
namespace ge { | |||
const int kValueIndexOutputIndex = 1; | |||
@@ -47,13 +43,12 @@ Status MergePass::Run(NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
auto out_data_anchors = node->GetAllOutDataAnchors(); | |||
if (out_data_anchors.empty()) { | |||
if (node->GetAllOutDataAnchors().empty()) { | |||
GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); | |||
return PARAM_INVALID; | |||
} | |||
auto in_data_nodes = node->GetInDataNodes(); | |||
const auto &in_data_nodes = node->GetInDataNodes(); | |||
switch (in_data_nodes.size()) { | |||
case 0: { | |||
/// Case A: input_count = 0, the output of merge node is inactive as well | |||
@@ -22,9 +22,6 @@ | |||
#include "graph/common/omg_util.h" | |||
#include "graph/utils/type_utils.h" | |||
using std::string; | |||
using std::vector; | |||
namespace ge { | |||
Status MultiBatchPass::Run(ComputeGraphPtr graph) { | |||
GELOGD("MultiBatchPass Enter"); | |||
@@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { | |||
return FAILED; | |||
} | |||
std::vector<std::vector<int64_t>> batch_shape; | |||
vector<vector<int64_t>> combined_batch; | |||
std::vector<std::vector<int64_t>> combined_batch; | |||
if (!CheckSwitchN(batch_shape, combined_batch)) { | |||
GELOGE(FAILED, "CheckSwitchN failed."); | |||
return FAILED; | |||
@@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() { | |||
/// | |||
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { | |||
const auto &func_desc = case_node->GetOpDesc(); | |||
GE_CHECK_NOTNULL(func_desc); | |||
if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { | |||
GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); | |||
return SUCCESS; | |||
@@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr | |||
const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); | |||
GE_CHECK_NOTNULL(subgraph); | |||
const string batch_label = "Batch_" + std::to_string(i); | |||
const std::string batch_label = "Batch_" + std::to_string(i); | |||
for (const auto &node : subgraph->GetDirectNode()) { | |||
(void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); | |||
} | |||
@@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||
continue; | |||
} | |||
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
if (in_data_anchor == nullptr) { | |||
GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); | |||
const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); | |||
if (pred_input == nullptr) { | |||
GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); | |||
return FAILED; | |||
@@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||
/// @return Status | |||
/// | |||
Status MultiBatchPass::GetDynamicType() { | |||
for (const auto &switchn : switch_n_nodes_) { | |||
auto switchn_desc = switchn->GetOpDesc(); | |||
GE_CHECK_NOTNULL(switchn_desc); | |||
for (const auto &switch_n : switch_n_nodes_) { | |||
int32_t dynamic_type = static_cast<int32_t>(FIXED); | |||
if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); | |||
if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); | |||
return FAILED; | |||
} | |||
if (dynamic_type == static_cast<int32_t>(FIXED)) { | |||
@@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() { | |||
return FAILED; | |||
} | |||
if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { | |||
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", | |||
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", | |||
dynamic_type, dynamic_type_); | |||
return FAILED; | |||
} | |||
@@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() { | |||
Status MultiBatchPass::GetUserDesignateShape() { | |||
data_name_order_.clear(); | |||
bool first_check = true; | |||
for (const auto &switchn : switch_n_nodes_) { | |||
auto switchn_desc = switchn->GetOpDesc(); | |||
GE_CHECK_NOTNULL(switchn_desc); | |||
vector<string> cur_switchn_data_name_order; | |||
if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { | |||
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); | |||
for (const auto &switch_n : switch_n_nodes_) { | |||
std::vector<std::string> cur_data_name_order; | |||
if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { | |||
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); | |||
return FAILED; | |||
} | |||
if (first_check) { | |||
data_name_order_ = cur_switchn_data_name_order; | |||
data_name_order_ = cur_data_name_order; | |||
first_check = false; | |||
} else { | |||
if (data_name_order_ != cur_switchn_data_name_order) { | |||
if (data_name_order_ != cur_data_name_order) { | |||
GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", | |||
switchn->GetName().c_str()); | |||
switch_n->GetName().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
@@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() { | |||
/// @param [out] combined_batch | |||
/// @return bool | |||
/// | |||
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) { | |||
bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape, | |||
std::vector<std::vector<int64_t>> &combined_batch) { | |||
// Check if output_num of different SwitchN is same | |||
uint32_t batch_num = 0; | |||
for (const NodePtr &node : switch_n_nodes_) { | |||
@@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||
} | |||
size_t tmp_combined_dim_num = combined_batch[i].size(); | |||
if (combined_dim_num != tmp_combined_dim_num) { | |||
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); | |||
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", | |||
combined_dim_num, i, tmp_combined_dim_num); | |||
return false; | |||
} | |||
} | |||
@@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||
/// @param [out] combined_batch | |||
/// @return bool | |||
/// | |||
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape, | |||
vector<vector<int64_t>> &combined_batch) { | |||
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape, | |||
std::vector<std::vector<int64_t>> &combined_batch) { | |||
// Check if output_shape of different SwitchN is same | |||
vector<vector<int64_t>> idx_batch_shape; | |||
vector<vector<int64_t>> idx_combined_batch; | |||
std::vector<std::vector<int64_t>> idx_batch_shape; | |||
std::vector<std::vector<int64_t>> idx_combined_batch; | |||
for (uint32_t i = 0; i < batch_num; i++) { | |||
idx_batch_shape.clear(); | |||
idx_combined_batch.clear(); | |||
@@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b | |||
GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); | |||
return false; | |||
} | |||
vector<int64_t> output_dims; | |||
std::vector<int64_t> output_dims; | |||
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { | |||
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); | |||
return false; | |||
@@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { | |||
/// @return Status | |||
/// | |||
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, | |||
const vector<vector<int64_t>> &batch_shape, | |||
const vector<vector<int64_t>> &combined_batch) { | |||
const std::vector<std::vector<int64_t>> &batch_shape, | |||
const std::vector<std::vector<int64_t>> &combined_batch) { | |||
NodePtr pred_value_node = pred_value->GetOwnerNode(); | |||
// Create SwitchCase node | |||
const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; | |||
@@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||
return false; | |||
} | |||
size_t num = output_shape.size(); | |||
size_t dim_num = output_shape[0].size(); | |||
for (size_t i = 1; i < num; i++) { | |||
size_t tmp_dim_num = output_shape[i].size(); | |||
if (dim_num != tmp_dim_num) { | |||
GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); | |||
for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) { | |||
if (output_shape[0] != *iter) { | |||
return false; | |||
} | |||
} | |||
if (dim_num == 0) { | |||
return true; | |||
} | |||
for (size_t i = 0; i < dim_num; i++) { | |||
int64_t dim_value = output_shape[0][i]; | |||
for (size_t j = 1; j < num; j++) { | |||
int64_t tmp_dim_value = output_shape[j][i]; | |||
if (dim_value != tmp_dim_value) { | |||
GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, | |||
dim_value, j, tmp_dim_value); | |||
return false; | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
@@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||
/// | |||
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, | |||
const OutDataAnchorPtr &pred_value, | |||
const vector<vector<int64_t>> &batch_shape, | |||
const vector<vector<int64_t>> &combined_batch) { | |||
const std::vector<std::vector<int64_t>> &batch_shape, | |||
const std::vector<std::vector<int64_t>> &combined_batch) { | |||
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); | |||
if (op_desc == nullptr) { | |||
GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); | |||
@@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const | |||
GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); | |||
return nullptr; | |||
} | |||
const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||
const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||
if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { | |||
GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); | |||
return nullptr; | |||
@@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra | |||
std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; | |||
for (const NodePtr &node : graph->GetDirectNode()) { | |||
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); | |||
if ((type == SWITCH) || (type == REFSWITCH)) { | |||
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
GE_CHECK_NOTNULL(in_cond_anchor); | |||
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(peer_out_anchor); | |||
if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { | |||
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
if ((type != SWITCH) && (type != REFSWITCH)) { | |||
continue; | |||
} | |||
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||
GE_CHECK_NOTNULL(in_cond_anchor); | |||
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(peer_out_anchor); | |||
if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) { | |||
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||
auto iter = cond_switch_map.find(cond_node); | |||
if (iter == cond_switch_map.end()) { | |||
cond_switch_map[cond_node] = { node }; | |||
} else { | |||
iter->second.emplace_back(node); | |||
} | |||
switch_nodes_.emplace_back(node); | |||
NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||
auto iter = cond_switch_map.find(cond_node); | |||
if (iter == cond_switch_map.end()) { | |||
cond_switch_map[cond_node] = { node }; | |||
} else { | |||
iter->second.emplace_back(node); | |||
} | |||
switch_nodes_.emplace_back(node); | |||
} | |||
MarkCycleDependence(cond_switch_map); | |||
@@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||
if (idx == SWITCH_DATA_INPUT) { | |||
peer_data_anchor = peer_out_anchor; | |||
} else { | |||
if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { | |||
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
peer_cond_anchor = peer_out_anchor; | |||
} | |||
} | |||
@@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||
/// | |||
/// @brief Find Switch cond input | |||
/// @param [in] pass_switch_flag | |||
/// @param [out] peer_cond_anchor | |||
/// @return Status | |||
/// | |||
Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { | |||
Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) { | |||
NodePtr tmp_node = nullptr; | |||
string type; | |||
bool need_pass_type = true; | |||
while (need_pass_type) { | |||
std::string type; | |||
bool pass_flag = true; | |||
while (pass_flag) { | |||
if (tmp_node == nullptr) { | |||
tmp_node = peer_cond_anchor->GetOwnerNode(); | |||
} else { | |||
@@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD | |||
} | |||
GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); | |||
need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); | |||
pass_flag = ((type == SWITCH) || (type == REFSWITCH)); | |||
} | |||
return SUCCESS; | |||
@@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||
} | |||
} else { | |||
int64_t switch_group_id = GetGroupId(stream_switch); | |||
map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||
std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||
std::list<NodePtr> false_node_list; | |||
std::list<NodePtr> true_node_list; | |||
std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; | |||
@@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||
/// @return group_id | |||
/// | |||
int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
string tailing_optimization_option; | |||
std::string tailing_optimization_option; | |||
bool is_tailing_optimization = false; | |||
if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { | |||
// "1" means it's True from frontend option | |||
@@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
return 0; | |||
} | |||
string hccl_group_id; | |||
std::string hccl_group_id; | |||
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { | |||
GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); | |||
return 0; | |||
@@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | |||
OutDataAnchorPtr peer_cond_anchor = iter->first; | |||
GE_CHECK_NOTNULL(peer_cond_anchor); | |||
NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); | |||
GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); | |||
@@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con | |||
NodePtr cast_node = graph->AddNode(cast_desc); | |||
GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); | |||
// Cast node has and only has one input | |||
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); | |||
return cast_node; | |||
@@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no | |||
return INTERNAL_ERROR; | |||
} | |||
for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { | |||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||
for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) { | |||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||
"Remove ctl edge failed."); | |||
GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
"Add ctl edge failed."); | |||
}); | |||
GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); | |||
if (same_cond_switch.count(in_ctl_node) > 0) { | |||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue); | |||
if (same_cond_switch.count(in_ctrl_node) > 0) { | |||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||
"Remove ctl edge failed."); | |||
continue; | |||
} | |||
auto find_res1 = switch_node_map_.find(in_ctl_node); | |||
auto find_res1 = switch_node_map_.find(in_ctrl_node); | |||
GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { | |||
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
}); | |||
auto find_res2 = find_res1->second.find(orig_switch_name); | |||
@@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass { | |||
/// | |||
/// @brief Find Switch cond input | |||
/// @param [in] pass_switch_flag | |||
/// @param [out] peer_cond_anchor | |||
/// @return Status | |||
/// | |||
Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); | |||
Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); | |||
/// | |||
/// @brief Create StreamSwitch Node | |||