Browse Source

update metadef

pull/1363/head
gengchao4@huawei.com 4 years ago
parent
commit
d6e40f6a7d
1 changed files with 66 additions and 0 deletions
  1. +66
    -0
      tests/st/testcase/test_model_deploy_schedule.cc

+ 66
- 0
tests/st/testcase/test_model_deploy_schedule.cc View File

@@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "external/ge/ge_api.h"
#include "ge_running_env/ge_running_env_faker.h"
#include "ge_graph_dsl/graph_dsl.h"
#include "ge_graph_dsl/assert/graph_assert.h"
namespace ge {
class TestModelDeploySchedule : public testing::Test {
protected:
void SetUp() { ge_env.InstallDefault(); }
void TearDown() {}
GeRunningEnvFaker ge_env;
};
TEST_F(TestModelDeploySchedule, test_data_slice) {
DEF_GRAPH(g1) {
auto data = std::make_shared<OpDesc>("data1", DATA);
auto var = std::make_shared<OpDesc>("var1", VARIABLE);
auto conv = std::make_shared<OpDesc>("conv1", CONV2D);
// data output 0
auto data_output_desc = data->MutableOutputDesc(0);
std::vector<std::vector<int64_t>> cut_attr = {{0, 0, 0, 0}};
ge::AttrUtils::SetListListInt(data_output_desc, "cut_info", cut_attr);
// conv input 0
auto conv_input_desc = conv->MutableInputDesc(0);
ge::AttrUtils::SetListListInt(conv_input_desc, "cut_info", {{1,0,0,0}, {0,0,0,0}});
// conv input 1
auto input_desc = conv->MutableInputDesc(1);
ge::AttrUtils::SetListListInt(input_desc, "cut_info", {{0, 0, 0, 0}, {0, 0, 1, 0}});
CHAIN(NODE(data)->NODE(conv));
CHAIN(NODE(var)->NODE(conv));
// CHAIN(NODE("data1", DATA)->NODE("conv1", CONV2D));
// CHAIN(NODE("var1", VARIABLE)->NODE("conv1"));
};
map<AscendString, AscendString> options;
// TODO: add option to enable mds pass
Session session(options);
session.AddGraph(1, ToGeGraph(g1), options);
std::vector<InputTensorInfo> inputs;
auto ret = session.BuildGraph(1, inputs);
EXPECT_EQ(ret, SUCCESS);
CHECK_GRAPH(PreRunAfterBuild) {
ASSERT_EQ(graph->GetName(), "g1_1");
ASSERT_EQ(graph->GetAllNodesSize(), 4);
};
}
} // namespace ge

Loading…
Cancel
Save