|
@@ -19,7 +19,8 @@ |
|
|
#include <string> |
|
|
#include <string> |
|
|
|
|
|
|
|
|
#define private public |
|
|
#define private public |
|
|
|
|
|
|
|
|
|
|
|
#include "inc/graph/ge_local_context.h" |
|
|
|
|
|
#include "inc/external/ge/ge_api_types.h" |
|
|
#include "common/ge_inner_error_codes.h" |
|
|
#include "common/ge_inner_error_codes.h" |
|
|
#include "inc/pass_manager.h" |
|
|
#include "inc/pass_manager.h" |
|
|
#include "utils/graph_utils.h" |
|
|
#include "utils/graph_utils.h" |
|
@@ -225,6 +226,70 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { |
|
|
output_true_node_->GetOpDesc()->SetIsInputConst({false}); |
|
|
output_true_node_->GetOpDesc()->SetIsInputConst({false}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void BuildDefaultGraph3() { |
|
|
|
|
|
/// input |
|
|
|
|
|
/// \ |
|
|
|
|
|
/// sqrt pred |
|
|
|
|
|
/// \ / |
|
|
|
|
|
/// Switch |
|
|
|
|
|
/// | | |
|
|
|
|
|
/// F T ------ |
|
|
|
|
|
/// / \_/_ \ |
|
|
|
|
|
/// / / \ \ |
|
|
|
|
|
/// Merge sqrt2 sqrt3 |
|
|
|
|
|
/// / \ \ |
|
|
|
|
|
/// sqrt1 \ relu |
|
|
|
|
|
/// \ \ |
|
|
|
|
|
/// \ sqrt4 |
|
|
|
|
|
/// \ / |
|
|
|
|
|
/// Merge1 |
|
|
|
|
|
input_node_ = NewNode("input", RELU, 0, 1); |
|
|
|
|
|
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); |
|
|
|
|
|
pred_node_ = NewNode("pred", GREATER, 2, 1); |
|
|
|
|
|
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); |
|
|
|
|
|
cast_node_ = NewNode("cast", CAST, 2, 2); |
|
|
|
|
|
|
|
|
|
|
|
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); |
|
|
|
|
|
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); |
|
|
|
|
|
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); |
|
|
|
|
|
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); |
|
|
|
|
|
output_false_node_ = NewNode("false_output", RELU, 1, 2); |
|
|
|
|
|
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); |
|
|
|
|
|
output_true_node_ = NewNode("true_output", RELU, 1, 2); |
|
|
|
|
|
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); |
|
|
|
|
|
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); |
|
|
|
|
|
sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); |
|
|
|
|
|
AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); |
|
|
|
|
|
sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1); |
|
|
|
|
|
AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); |
|
|
|
|
|
sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1); |
|
|
|
|
|
relu_node_ = NewNode("relu", RELU, 1, 1); |
|
|
|
|
|
sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1); |
|
|
|
|
|
AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); |
|
|
|
|
|
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); |
|
|
|
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); |
|
|
|
|
|
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); |
|
|
|
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); |
|
|
|
|
|
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0)); |
|
|
|
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0)); |
|
|
|
|
|
GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1)); |
|
|
|
|
|
output_false_node_->GetOpDesc()->SetIsInputConst({false}); |
|
|
|
|
|
output_true_node_->GetOpDesc()->SetIsInputConst({false}); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
ComputeGraphPtr graph_; |
|
|
ComputeGraphPtr graph_; |
|
|
ComputeGraphPtr sub_graph_; |
|
|
ComputeGraphPtr sub_graph_; |
|
|
GeTensorDescPtr default_tensor_desc_; |
|
|
GeTensorDescPtr default_tensor_desc_; |
|
@@ -235,6 +300,9 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { |
|
|
NodePtr cast_node1_; |
|
|
NodePtr cast_node1_; |
|
|
NodePtr sqrt_node_; |
|
|
NodePtr sqrt_node_; |
|
|
NodePtr sqrt_node1_; |
|
|
NodePtr sqrt_node1_; |
|
|
|
|
|
NodePtr sqrt_node2_; |
|
|
|
|
|
NodePtr sqrt_node3_; |
|
|
|
|
|
NodePtr sqrt_node4_; |
|
|
NodePtr input_node_; |
|
|
NodePtr input_node_; |
|
|
NodePtr input_node1_; |
|
|
NodePtr input_node1_; |
|
|
NodePtr switch_node_t; |
|
|
NodePtr switch_node_t; |
|
@@ -278,6 +346,16 @@ TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { |
|
|
EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); |
|
|
EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) { |
|
|
|
|
|
std::map<std::string, std::string> options; |
|
|
|
|
|
options.emplace(OPTION_GRAPH_RUN_MODE, "1"); |
|
|
|
|
|
GetThreadLocalContext().SetGraphOption(options); |
|
|
|
|
|
BuildDefaultGraph3(); |
|
|
|
|
|
auto ret = pass_.Run(graph_); |
|
|
|
|
|
EXPECT_EQ(ret, GRAPH_SUCCESS); |
|
|
|
|
|
EXPECT_EQ(true, merge_node1_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor())); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { |
|
|
TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { |
|
|
BuildDefaultGraph1(); |
|
|
BuildDefaultGraph1(); |
|
|
NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); |
|
|
NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); |
|
|