|
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include <gtest/gtest.h>
- #include "external/ge/ge_api.h"
- #include "ge_running_env/ge_running_env_faker.h"
- #include "ge_graph_dsl/graph_dsl.h"
- #include "ge_graph_dsl/assert/graph_assert.h"
-
- namespace ge {
- // Test fixture for model deploy/schedule system tests.
- // Every test starts from the default environment installed by GeRunningEnvFaker.
- class TestModelDeploySchedule : public testing::Test {
-  protected:
-   // `override` makes the compiler verify these really replace testing::Test's hooks.
-   void SetUp() override { ge_env.InstallDefault(); }
-   void TearDown() override {}
-   GeRunningEnvFaker ge_env;
- };
-
- // Builds graph g1, where conv1's two inputs carry "cut_info" attributes.
- // Then compares the graph checked at the PreRunAfterBuild phase against the
- // expected graph g1_1. In g1_1, a Slice node feeds conv1 input 0 and that
- // input's dim 0 is halved.
- // NOTE(review): the session options are empty. Per the TODO below, no option
- // enabling the MDS pass is set yet, so g1_1 may not match until it is.
- TEST_F(TestModelDeploySchedule, test_data_slice) {
-   DEF_GRAPH(g1) {
-     auto conv = std::make_shared<OpDesc>("conv1", CONV2D);
-     // conv input 0
-     auto conv_input_desc = conv->MutableInputDesc(0);
-     ge::AttrUtils::SetListListInt(conv_input_desc, "cut_info", {{1, 0, 0, 0}, {0, 0, 0, 0}});
-     // conv input 1
-     auto input_desc = conv->MutableInputDesc(1);
-     ge::AttrUtils::SetListListInt(input_desc, "cut_info", {{0, 0, 0, 0}, {0, 0, 1, 0}});
-     CHAIN(NODE("data1", DATA)->NODE(conv));
-     CHAIN(NODE("var1", VARIABLE)->NODE(conv));
-   };
-
-   map<AscendString, AscendString> options;
-   // TODO: add option to enable mds pass
-   Session session(options);
-   ASSERT_EQ(session.AddGraph(1, ToGeGraph(g1), options), SUCCESS);
-   std::vector<InputTensorInfo> inputs;
-   auto ret = session.BuildGraph(1, inputs);
-   // Stop on a failed build: checking the phase graph afterwards would be meaningless.
-   ASSERT_EQ(ret, SUCCESS);
-
-   DEF_GRAPH(g1_1) {
-     auto conv = std::make_shared<OpDesc>("conv1", CONV2D);
-     // conv input 0: same cut_info as in g1, with dim 0 of its shape halved.
-     auto conv_input_desc = conv->MutableInputDesc(0);
-     ge::AttrUtils::SetListListInt(conv_input_desc, "cut_info", {{1, 0, 0, 0}, {0, 0, 0, 0}});
-     // conv1 is created without input descs, so MutableInputDesc may return nullptr.
-     // The shape may also have no dims. Guard both before touching dim 0.
-     std::vector<int64_t> dims;
-     if (conv_input_desc != nullptr) {
-       dims = conv_input_desc->GetShape().GetDims();
-       if (!dims.empty()) {
-         dims[0] /= 2;
-       }
-       conv_input_desc->SetShape(GeShape(dims));
-     }
-     // conv input 1
-     auto input_desc = conv->MutableInputDesc(1);
-     ge::AttrUtils::SetListListInt(input_desc, "cut_info", {{0, 0, 0, 0}, {0, 0, 1, 0}});
-
-     // NOTE(review): "ddd " (with a trailing space) looks like a placeholder attribute name -- confirm.
-     auto slice = OP_CFG(SLICE).Attr("ddd ", 1).TensorDescShape(dims);
-     CHAIN(NODE("data1", DATA)->EDGE(0, 0)->NODE("any_slice", slice)->EDGE(0, 0)->NODE(conv));
-     // NOTE(review): this wires a second producer into any_slice input 0 (data1 already
-     // uses it). A different input index may be intended -- confirm.
-     CHAIN(NODE("any_var", VARIABLE)->EDGE(0, 0)->NODE("any_slice"));
-     CHAIN(NODE("var1", VARIABLE)->EDGE(0, 1)->NODE(conv));
-   };
-
-   CHECK_GRAPH(PreRunAfterBuild) {
-     ASSERT_GRAPH_CMP(graph, g1_1);
-     ASSERT_EQ(graph->GetName(), "g1_1");
-     // NOTE(review): g1_1 above declares 5 nodes (data1, any_slice, conv1, any_var, var1),
-     // but 4 is asserted here -- confirm which is intended.
-     ASSERT_EQ(graph->GetAllNodesSize(), 4);
-   };
- }
- } // namespace ge
|