|
|
|
@@ -0,0 +1,66 @@ |
|
|
|
/**
|
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include "external/ge/ge_api.h"
|
|
|
|
#include "ge_running_env/ge_running_env_faker.h"
|
|
|
|
#include "ge_graph_dsl/graph_dsl.h"
|
|
|
|
#include "ge_graph_dsl/assert/graph_assert.h"
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
class TestModelDeploySchedule : public testing::Test {
|
|
|
|
protected:
|
|
|
|
void SetUp() { ge_env.InstallDefault(); }
|
|
|
|
void TearDown() {}
|
|
|
|
GeRunningEnvFaker ge_env;
|
|
|
|
};
|
|
|
|
|
|
|
|
TEST_F(TestModelDeploySchedule, test_data_slice) {
|
|
|
|
DEF_GRAPH(g1) {
|
|
|
|
auto data = std::make_shared<OpDesc>("data1", DATA);
|
|
|
|
auto var = std::make_shared<OpDesc>("var1", VARIABLE);
|
|
|
|
auto conv = std::make_shared<OpDesc>("conv1", CONV2D);
|
|
|
|
|
|
|
|
// data output 0
|
|
|
|
auto data_output_desc = data->MutableOutputDesc(0);
|
|
|
|
std::vector<std::vector<int64_t>> cut_attr = {{0, 0, 0, 0}};
|
|
|
|
ge::AttrUtils::SetListListInt(data_output_desc, "cut_info", cut_attr);
|
|
|
|
// conv input 0
|
|
|
|
auto conv_input_desc = conv->MutableInputDesc(0);
|
|
|
|
ge::AttrUtils::SetListListInt(conv_input_desc, "cut_info", {{1,0,0,0}, {0,0,0,0}});
|
|
|
|
// conv input 1
|
|
|
|
auto input_desc = conv->MutableInputDesc(1);
|
|
|
|
ge::AttrUtils::SetListListInt(input_desc, "cut_info", {{0, 0, 0, 0}, {0, 0, 1, 0}});
|
|
|
|
CHAIN(NODE(data)->NODE(conv));
|
|
|
|
CHAIN(NODE(var)->NODE(conv));
|
|
|
|
// CHAIN(NODE("data1", DATA)->NODE("conv1", CONV2D));
|
|
|
|
// CHAIN(NODE("var1", VARIABLE)->NODE("conv1"));
|
|
|
|
};
|
|
|
|
|
|
|
|
map<AscendString, AscendString> options;
|
|
|
|
// TODO: add option to enable mds pass
|
|
|
|
Session session(options);
|
|
|
|
session.AddGraph(1, ToGeGraph(g1), options);
|
|
|
|
std::vector<InputTensorInfo> inputs;
|
|
|
|
auto ret = session.BuildGraph(1, inputs);
|
|
|
|
|
|
|
|
EXPECT_EQ(ret, SUCCESS);
|
|
|
|
CHECK_GRAPH(PreRunAfterBuild) {
|
|
|
|
ASSERT_EQ(graph->GetName(), "g1_1");
|
|
|
|
ASSERT_EQ(graph->GetAllNodesSize(), 4);
|
|
|
|
};
|
|
|
|
}
|
|
|
|
} // namespace ge |