| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "external/ge/ge_api.h" | |||||
| #include "ge_running_env/ge_running_env_faker.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | |||||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||||
| namespace ge { | |||||
| class TestModelDeploySchedule : public testing::Test { | |||||
| protected: | |||||
| void SetUp() { ge_env.InstallDefault(); } | |||||
| void TearDown() {} | |||||
| GeRunningEnvFaker ge_env; | |||||
| }; | |||||
| TEST_F(TestModelDeploySchedule, test_data_slice) { | |||||
| DEF_GRAPH(g1) { | |||||
| auto data = std::make_shared<OpDesc>("data1", DATA); | |||||
| auto var = std::make_shared<OpDesc>("var1", VARIABLE); | |||||
| auto conv = std::make_shared<OpDesc>("conv1", CONV2D); | |||||
| // data output 0 | |||||
| auto data_output_desc = data->MutableOutputDesc(0); | |||||
| std::vector<std::vector<int64_t>> cut_attr = {{0, 0, 0, 0}}; | |||||
| ge::AttrUtils::SetListListInt(data_output_desc, "cut_info", cut_attr); | |||||
| // conv input 0 | |||||
| auto conv_input_desc = conv->MutableInputDesc(0); | |||||
| ge::AttrUtils::SetListListInt(conv_input_desc, "cut_info", {{1,0,0,0}, {0,0,0,0}}); | |||||
| // conv input 1 | |||||
| auto input_desc = conv->MutableInputDesc(1); | |||||
| ge::AttrUtils::SetListListInt(input_desc, "cut_info", {{0, 0, 0, 0}, {0, 0, 1, 0}}); | |||||
| CHAIN(NODE(data)->NODE(conv)); | |||||
| CHAIN(NODE(var)->NODE(conv)); | |||||
| // CHAIN(NODE("data1", DATA)->NODE("conv1", CONV2D)); | |||||
| // CHAIN(NODE("var1", VARIABLE)->NODE("conv1")); | |||||
| }; | |||||
| map<AscendString, AscendString> options; | |||||
| // TODO: add option to enable mds pass | |||||
| Session session(options); | |||||
| session.AddGraph(1, ToGeGraph(g1), options); | |||||
| std::vector<InputTensorInfo> inputs; | |||||
| auto ret = session.BuildGraph(1, inputs); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| CHECK_GRAPH(PreRunAfterBuild) { | |||||
| ASSERT_EQ(graph->GetName(), "g1_1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 4); | |||||
| }; | |||||
| } | |||||
| } // namespace ge | |||||