Browse Source

fix data directlly connect netoutput scene

tags/v1.3.0
wxl 3 years ago
parent
commit
93b6dff0d7
1 changed files with 28 additions and 1 deletions
  1. +28
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 28
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -39,7 +39,7 @@
#include "hybrid/common/npu_memory_allocator.h" #include "hybrid/common/npu_memory_allocator.h"
#include "graph/types.h" #include "graph/types.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
#undef private #undef private
#undef protected #undef protected


@@ -173,6 +173,33 @@ TEST_F(UtestGeHybrid, parse_force_infershape_nodes) {
HybridModelBuilder hybrid_model_builder(hybrid_model); HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS);
} }
static ComputeGraphPtr BuildDataDirectConnectGraph() {
ge::ut::GraphBuilder builder("subgraph");
auto data = builder.AddNode("Data", "Data", 1, 1);
auto netoutput = builder.AddNode("Netoutput", "Netoutput", 1, 1);

builder.AddDataEdge(data, 0, netoutput, 0);
return builder.GetGraph();
}
TEST_F(UtestGeHybrid, data_direct_connect) {
std::unique_ptr<NodeItem> node_item;
auto root_graph = make_shared<ComputeGraph>("root_graph");
OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall");
auto node = root_graph->AddNode(op_desc);
auto sub_graph = BuildDataDirectConnectGraph();
sub_graph->SetParentGraph(root_graph);
sub_graph->SetParentNode(node);
node->GetOpDesc()->AddSubgraphName("subgraph");
node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph");
root_graph->AddSubgraph("subgraph", sub_graph);
std::unique_ptr<NodeItem> new_node;
NodeItem::Create(node, new_node);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get());
ASSERT_EQ(ret, SUCCESS);
}


TEST_F(UtestGeHybrid, index_taskdefs_success) { TEST_F(UtestGeHybrid, index_taskdefs_success) {
// build aicore task // build aicore task


Loading…
Cancel
Save