Browse Source

fix parallel group pass

tags/v1.5.1
陈华 3 years ago
parent
commit
2874ec935f
3 changed files with 119 additions and 20 deletions
  1. +39
    -19
      ge/graph/passes/parallel_group_pass.cc
  2. +1
    -0
      ge/graph/passes/parallel_group_pass.h
  3. +79
    -1
      tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc

+ 39
- 19
ge/graph/passes/parallel_group_pass.cc View File

@@ -15,7 +15,7 @@
*/

#include "graph/passes/parallel_group_pass.h"
#include <queue>
#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "framework/common/ge_inner_error_codes.h"
@@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu
for (const auto &switch_node : cur_itr->second.first) {
int64_t pre_id = pre_node->GetOpDesc()->GetId();
int64_t switch_id = switch_node->GetOpDesc()->GetId();
// avoid ring
if (pre_id > switch_id) {
auto merge_node = cur_itr->second.second;
if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
return FAILED;
}
} else {
if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
pre_node->GetName().c_str(), switch_node->GetName().c_str());
return FAILED;
}
NodePtr first_node = pre_node;
NodePtr second_node = switch_node;
if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) {
// avoid ring, merge->pre_node
first_node = cur_itr->second.second;
second_node = pre_node;
}
if (AddCtrlEdge(first_node, second_node) != SUCCESS) {
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
first_node->GetName().c_str(), second_node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
first_node->GetName().c_str(), second_node->GetName().c_str());
return FAILED;
}
}
} else {
@@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
stream_switch_type == kLoopType);
}

bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) {
if (node_a == nullptr || node_b == nullptr) {
GELOGW("node_a or node_b is nullptr.");
return false;
}
int64_t end_id = node_b->GetOpDesc()->GetId();
std::queue<NodePtr> nodes;
nodes.push(node_a);
while (!nodes.empty()) {
NodePtr tmp_node = nodes.front();
nodes.pop();
if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr ||
tmp_node->GetOpDesc()->GetId() > end_id) {
continue;
}
if (tmp_node == node_b) {
return true;
}
for (const auto &out_node : tmp_node->GetOutAllNodes()) {
nodes.push(out_node);
}
}
return false;
}
} // namespace ge

+ 1
- 0
ge/graph/passes/parallel_group_pass.h View File

@@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass {

bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc);
bool IsWhileStreamSwitch(OpDescPtr switch_op_desc);
bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H

+ 79
- 1
tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc View File

@@ -19,7 +19,8 @@
#include <string>

#define private public

#include "inc/graph/ge_local_context.h"
#include "inc/external/ge/ge_api_types.h"
#include "common/ge_inner_error_codes.h"
#include "inc/pass_manager.h"
#include "utils/graph_utils.h"
@@ -225,6 +226,70 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}

void BuildDefaultGraph3() {
/// input
/// \
/// sqrt pred
/// \ /
/// Switch
/// | |
/// F T ------
/// / \_/_ \
/// / / \ \
/// Merge sqrt2 sqrt3
/// / \ \
/// sqrt1 \ relu
/// \ \
/// \ sqrt4
/// \ /
/// Merge1
input_node_ = NewNode("input", RELU, 0, 1);
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
pred_node_ = NewNode("pred", GREATER, 2, 1);
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
cast_node_ = NewNode("cast", CAST, 2, 2);

switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
output_false_node_ = NewNode("false_output", RELU, 1, 2);
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
output_true_node_ = NewNode("true_output", RELU, 1, 2);
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1);
AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1);
relu_node_ = NewNode("relu", RELU, 1, 1);
sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1);
AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);

GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));

GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0));

GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1));
output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}

ComputeGraphPtr graph_;
ComputeGraphPtr sub_graph_;
GeTensorDescPtr default_tensor_desc_;
@@ -235,6 +300,9 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {
NodePtr cast_node1_;
NodePtr sqrt_node_;
NodePtr sqrt_node1_;
NodePtr sqrt_node2_;
NodePtr sqrt_node3_;
NodePtr sqrt_node4_;
NodePtr input_node_;
NodePtr input_node1_;
NodePtr switch_node_t;
@@ -278,6 +346,16 @@ TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) {
EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor()));
}

TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) {
std::map<std::string, std::string> options;
options.emplace(OPTION_GRAPH_RUN_MODE, "1");
GetThreadLocalContext().SetGraphOption(options);
BuildDefaultGraph3();
auto ret = pass_.Run(graph_);
EXPECT_EQ(ret, GRAPH_SUCCESS);
EXPECT_EQ(true, merge_node1_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor()));
}

TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) {
BuildDefaultGraph1();
NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);


Loading…
Cancel
Save