|
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include <gtest/gtest.h>
- #include "external/ge/ge_api.h"
- #include "ge_running_env/ge_running_env_faker.h"
- #include "ge_graph_dsl/graph_dsl.h"
- #include "ge_graph_dsl/assert/graph_assert.h"
-
- namespace ge {
- // Fixture for model-deploy-schedule tests: each test runs against the
- // faked GE running environment installed by GeRunningEnvFaker.
- class TestModelDeploySchedule : public testing::Test {
-  protected:
-   void SetUp() override { ge_env.InstallDefault(); }
-   GeRunningEnvFaker ge_env;
- };
-
- // Builds data1 -> conv1 <- var1, annotates the data output and both conv
- // inputs with "cut_info" list-list-int attributes, builds the graph through a
- // Session and checks the graph seen at the PreRunAfterBuild stage.
- TEST_F(TestModelDeploySchedule, test_data_slice) {
-   DEF_GRAPH(g1) {
-     auto data = std::make_shared<OpDesc>("data1", DATA);
-     auto var = std::make_shared<OpDesc>("var1", VARIABLE);
-     auto conv = std::make_shared<OpDesc>("conv1", CONV2D);
-
-     // A freshly constructed OpDesc owns no tensor descs, so Mutable*Desc(i)
-     // would not return a usable desc and the cut_info attributes below would
-     // be silently dropped. Create the descs the attributes and edges rely on.
-     data->AddOutputDesc(GeTensorDesc());
-     var->AddOutputDesc(GeTensorDesc());
-     conv->AddInputDesc(GeTensorDesc());   // input 0: fed by data1
-     conv->AddInputDesc(GeTensorDesc());   // input 1: fed by var1
-     conv->AddOutputDesc(GeTensorDesc());
-
-     // data output 0
-     auto data_output_desc = data->MutableOutputDesc(0);
-     EXPECT_TRUE(data_output_desc != nullptr);
-     const std::vector<std::vector<int64_t>> data_cut_info = {{0, 0, 0, 0}};
-     EXPECT_TRUE(ge::AttrUtils::SetListListInt(data_output_desc, "cut_info", data_cut_info));
-
-     // conv input 0
-     auto conv_input0_desc = conv->MutableInputDesc(0);
-     EXPECT_TRUE(conv_input0_desc != nullptr);
-     const std::vector<std::vector<int64_t>> conv_input0_cut_info = {{1, 0, 0, 0}, {0, 0, 0, 0}};
-     EXPECT_TRUE(ge::AttrUtils::SetListListInt(conv_input0_desc, "cut_info", conv_input0_cut_info));
-
-     // conv input 1
-     auto conv_input1_desc = conv->MutableInputDesc(1);
-     EXPECT_TRUE(conv_input1_desc != nullptr);
-     const std::vector<std::vector<int64_t>> conv_input1_cut_info = {{0, 0, 0, 0}, {0, 0, 1, 0}};
-     EXPECT_TRUE(ge::AttrUtils::SetListListInt(conv_input1_desc, "cut_info", conv_input1_cut_info));
-
-     CHAIN(NODE(data)->NODE(conv));
-     CHAIN(NODE(var)->NODE(conv));
-   };
-
-   std::map<AscendString, AscendString> options;
-   // TODO: add an option to enable the MDS pass; presumably the cut_info
-   // attributes are not consumed until it is enabled.
-   Session session(options);
-   auto ret = session.AddGraph(1, ToGeGraph(g1), options);
-   ASSERT_EQ(ret, SUCCESS);
-
-   std::vector<InputTensorInfo> inputs;
-   ret = session.BuildGraph(1, inputs);
-   // Stop here on failure: inspecting the graph of a failed build is meaningless.
-   ASSERT_EQ(ret, SUCCESS);
-
-   CHECK_GRAPH(PreRunAfterBuild) {
-     ASSERT_EQ(graph->GetName(), "g1_1");
-     ASSERT_EQ(graph->GetAllNodesSize(), 4);
-   };
- }
- }  // namespace ge
|