@@ -823,7 +823,7 @@ set(PROFILING_MNG_TEST_FILES | |||||
set(HYBRID_TEST_FILES | set(HYBRID_TEST_FILES | ||||
"hybrid/ge_hybrid_unittest.cc" | "hybrid/ge_hybrid_unittest.cc" | ||||
"hybrid/known_node_executor_unittest.cc" | "hybrid/known_node_executor_unittest.cc" | ||||
"hybrid/executor/worker/execution_engine_unittest.cc" | |||||
"hybrid/executor/node_state_unittest.cc" | |||||
"hybrid/executor/subgraph_executor_unittest.cc" | "hybrid/executor/subgraph_executor_unittest.cc" | ||||
"hybrid/executor/worker/execution_engine_unittest.cc" | "hybrid/executor/worker/execution_engine_unittest.cc" | ||||
"hybrid/model/hybrid_model_builder_unittest.cc" | "hybrid/model/hybrid_model_builder_unittest.cc" | ||||
@@ -15,20 +15,17 @@ | |||||
*/ | */ | ||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#define private public | |||||
#define protected public | |||||
#include "graph/partition/dynamic_shape_partition.h" | #include "graph/partition/dynamic_shape_partition.h" | ||||
#include "compute_graph.h" | #include "compute_graph.h" | ||||
#include "inc/framework/common/types.h" | #include "inc/framework/common/types.h" | ||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#define private public | |||||
#define protected public | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW, | GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW, | ||||
DataType data_type = DT_FLOAT) { | DataType data_type = DT_FLOAT) { | ||||
GeShape ge_shape{vector<int64_t>(shape)}; | GeShape ge_shape{vector<int64_t>(shape)}; | ||||
@@ -94,4 +91,29 @@ TEST_F(UtestDynamicShapePartition, single_op_scene_success) { | |||||
DynamicShapePartitioner partitioner(graph); | DynamicShapePartitioner partitioner(graph); | ||||
EXPECT_EQ(partitioner.Partition(), SUCCESS); | EXPECT_EQ(partitioner.Partition(), SUCCESS); | ||||
} | } | ||||
TEST_F(UtestDynamicShapePartition, merge_control_flow_group) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("default"); | |||||
AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, "session_graph_id"); | |||||
NodePtr data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||||
NodePtr data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||||
NodePtr merge = NodeBuilder("node2", MERGE).AddInputDesc({1}).AddInputDesc({1}) | |||||
.AddOutputDesc({1}).AddOutputDesc({}).Build(graph); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), merge->GetInDataAnchor(1)); | |||||
(void)AttrUtils::SetBool(data1->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); | |||||
(void)AttrUtils::SetBool(data2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
(void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); | |||||
(void)AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
(void)AttrUtils::SetInt(merge->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); | |||||
EXPECT_EQ(graph->sub_graph_.size(), 0); | |||||
DynamicShapePartitioner partitioner(graph); | |||||
EXPECT_EQ(partitioner.Partition(), SUCCESS); | |||||
EXPECT_EQ(graph->sub_graph_.size(), 1); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -223,4 +223,17 @@ TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) { | |||||
auto data1_output = data1_desc->MutableOutputDesc(0); | auto data1_output = data1_desc->MutableOutputDesc(0); | ||||
EXPECT_EQ(data1_output->GetDataType(), 1); | EXPECT_EQ(data1_output->GetDataType(), 1); | ||||
} | } | ||||
TEST_F(UtestGraphPreproces, test_prepare_dyn_shape) { | |||||
ComputeGraphPtr compute_graph = BuildGraph5(); | |||||
GraphPtr graph_ptr = std::make_shared<Graph>(GraphUtils::CreateGraphFromComputeGraph(compute_graph)); | |||||
GraphNodePtr graph_node = make_shared<GraphNode>(0); | |||||
graph_node->SetComputeGraph(compute_graph); | |||||
graph_node->SetGraph(graph_ptr); | |||||
std::vector<GeTensor> user_input; | |||||
GraphPrepare graph_prepare; | |||||
EXPECT_EQ(graph_prepare.PrepareDynShape(graph_node, user_input, compute_graph, 0), SUCCESS); | |||||
} | |||||
} | } |
@@ -0,0 +1,106 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <gmock/gmock.h> | |||||
#include <vector> | |||||
#define private public | |||||
#define protected public | |||||
#include "hybrid/executor/node_state.h" | |||||
#include "hybrid/executor/subgraph_context.h" | |||||
#include "hybrid/model/graph_item.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
using namespace hybrid; | |||||
class UtestNodeState : public testing::Test { | |||||
protected: | |||||
void SetUp() { | |||||
} | |||||
void TearDown() { | |||||
} | |||||
}; | |||||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
op_desc->SetStreamId(0); | |||||
static int32_t index = 0; | |||||
op_desc->SetId(index++); | |||||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
TensorUtils::SetSize(tensor, 64); | |||||
vector<int64_t> input_offset; | |||||
for (int i = 0; i < in_num; i++) { | |||||
op_desc->AddInputDesc(tensor); | |||||
input_offset.emplace_back(index * 64 + i * 64); | |||||
} | |||||
op_desc->SetInputOffset(input_offset); | |||||
vector<int64_t> output_offset; | |||||
for (int i = 0; i < out_num; i++) { | |||||
op_desc->AddOutputDesc(tensor); | |||||
output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); | |||||
} | |||||
op_desc->SetOutputOffset(output_offset); | |||||
op_desc->SetWorkspace({}); | |||||
op_desc->SetWorkspaceBytes({}); | |||||
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
return graph.AddNode(op_desc); | |||||
} | |||||
TEST_F(UtestNodeState, merge_await_shapes_ready) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
const auto data0 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
const auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); | |||||
const auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | |||||
const auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
GraphUtils::AddEdge(data0->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
GraphItem graph_item; | |||||
GraphExecutionContext graph_context; | |||||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
std::unique_ptr<NodeItem> node_item; | |||||
NodeItem::Create(merge1, node_item); | |||||
NodeState node_state(*node_item, &subgraph_context); | |||||
// Not dynamic. | |||||
ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), SUCCESS); | |||||
// Not set merge index. | |||||
node_item->is_dynamic = true; | |||||
ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), FAILED); | |||||
// merge index out of bound. | |||||
AttrUtils::SetInt(merge1->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, 3); | |||||
ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), FAILED); | |||||
AttrUtils::SetInt(merge1->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, 1); | |||||
ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), SUCCESS); | |||||
} | |||||
} // namespace ge |