|
|
@@ -83,12 +83,16 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { |
|
|
|
graphStatus TestFunc(Operator &op) { return 0; } |
|
|
|
graphStatus TestFunc1(Operator &op) { return 1; } |
|
|
|
TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { |
|
|
|
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("graph_name"); |
|
|
|
auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); |
|
|
|
OperatorFactoryImpl::RegisterInferFormatFunc("Add", TestFunc); |
|
|
|
shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("add", "add"); |
|
|
|
compute_graph->AddNode(op_desc); |
|
|
|
GeGenerator generator; |
|
|
|
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc), SUCCESS); |
|
|
|
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc, graph), SUCCESS); |
|
|
|
shared_ptr<OpDesc> op_desc1 = make_shared<OpDesc>("Add", "Add"); |
|
|
|
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1), SUCCESS); |
|
|
|
compute_graph->AddNode(op_desc1); |
|
|
|
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc1, graph), SUCCESS); |
|
|
|
OperatorFactoryImpl::RegisterInferFormatFunc("MatMulV2", TestFunc1); |
|
|
|
shared_ptr<OpDesc> op_desc2 = make_shared<OpDesc>("MatMulV2", "MatMulV2"); |
|
|
|
GeTensorDesc tensor_desc; |
|
|
@@ -99,7 +103,8 @@ TEST_F(UtestGeGenerator, test_infer_format_for_single_op) { |
|
|
|
EXPECT_EQ(op_desc2->AddInputDesc(tensor_desc), GRAPH_SUCCESS); |
|
|
|
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); |
|
|
|
EXPECT_EQ(op_desc2->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); |
|
|
|
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2), FAILED); |
|
|
|
compute_graph->AddNode(op_desc2); |
|
|
|
EXPECT_EQ(generator.InferFormatForSingleOp(op_desc2, graph), FAILED); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestGeGenerator, test_build_single_op_online) { |
|
|
|