You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

guarantee_const_pass_unittest.cc 8.7 kB

5 years ago
(Lines 1–222)
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <unordered_map>
  18. #include <vector>
  19. #define protected public
  20. #define private public
  21. #include "graph/passes/guarantee_const_pass.h"
  22. #include "../ops_stub.h"
  23. #include "common/ge_inner_error_codes.h"
  24. #include "common/types.h"
  25. #include "graph/debug/ge_attr_define.h"
  26. #include "graph/utils/attr_utils.h"
  27. #include "graph/utils/graph_utils.h"
  28. #include "graph/utils/op_desc_utils.h"
  29. #include "graph/utils/tensor_utils.h"
  30. #include "inc/pass_manager.h"
  31. #undef protected
  32. #undef private
  33. using namespace domi;
  34. using namespace testing;
  35. using namespace ge;
  36. using namespace std;
  37. // To check whether the shape of output is correct or not
  38. #define TEST_OPERATOR(op_, input_shapes, output_shapes) \
  39. { \
  40. auto op = op_; \
  41. for (auto input_pair : input_shapes) SetInputShape(op, input_pair.first, input_pair.second); \
  42. op.InferShapeAndType(); \
  43. for (auto output_pair : output_shapes) CheckOutputShape(op, output_pair.first, output_pair.second); \
  44. }
  45. #define LOOP_VEC(v) for (unsigned i = 0; i < v.size(); i++)
  46. class UTEST_graph_passes_guarantee_const_pass : public testing::Test {
  47. protected:
  48. void SetUp() { init(); }
  49. void TearDown() { destory(); }
  50. private:
  51. void init() { guaranteeConstOpRemovePass = new ::ge::GuaranteeConstPass(); }
  52. void destory() {
  53. delete guaranteeConstOpRemovePass;
  54. guaranteeConstOpRemovePass = NULL;
  55. }
  56. protected:
  57. ge::GuaranteeConstPass *guaranteeConstOpRemovePass;
  58. void SetInputShape(Operator op, string name, vector<int64_t> shape) {
  59. TensorDesc td = op.GetInputDesc(name);
  60. td.SetShape(ge::Shape(shape));
  61. op.UpdateInputDesc(name, td);
  62. }
  63. void CheckOutputShape(Operator op, string name, vector<int64_t> shape) {
  64. ge::Shape s = op.GetOutputDesc(name).GetShape();
  65. EXPECT_EQ(s.GetDims().size(), shape.size());
  66. LOOP_VEC(shape) EXPECT_EQ(s.GetDim(i), shape[i]);
  67. }
  68. /// Init the node which will be passed in graph, isMultiInput represents whether using more than
  69. /// one data anchor or not.
  70. NodePtr init_node(ComputeGraphPtr graph, vector<int64_t> dims_vec, vector<int32_t> data_vec, bool isMultiInput,
  71. string type) {
  72. // middle
  73. OpDescPtr op_def = std::make_shared<OpDesc>("op_def", type);
  74. OpDescPtr in_op_def = std::make_shared<OpDesc>("op_def_in", "test");
  75. OpDescPtr out_op_def = std::make_shared<OpDesc>("op_def_out", "test");
  76. OpDescPtr another_in_op_def = std::make_shared<OpDesc>("another_op_def_in", "test");
  77. // whether using another input data anchor or not
  78. if (isMultiInput) {
  79. vector<bool> is_input_const_vec = {true, true};
  80. op_def->SetIsInputConst(is_input_const_vec);
  81. AttrUtils::SetInt(op_def, ge::ATTR_NAME_T, (int64_t)DT_INT32);
  82. }
  83. // input tensor;
  84. GeTensorDesc tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
  85. ge::ConstGeTensorPtr constTensor =
  86. std::make_shared<GeTensor>(tensor_desc, (uint8_t *)&data_vec[0], data_vec.size() * sizeof(int32_t));
  87. ge::AttrUtils::SetTensor(in_op_def, ge::ATTR_NAME_WEIGHTS, constTensor);
  88. op_def->AddInputDesc(tensor_desc);
  89. // whether using another input data anchor or not
  90. if (isMultiInput) {
  91. vector<int64_t> dims_vec_another = {6};
  92. vector<int32_t> data_vec_another = {1, 2, 3, 4, 5, 6};
  93. GeTensorDesc another_tensor_desc(GeShape(dims_vec_another), FORMAT_NCHW, DT_INT32);
  94. ge::ConstGeTensorPtr constTensor_another = std::make_shared<GeTensor>(
  95. another_tensor_desc, (uint8_t *)&data_vec_another[0], data_vec_another.size() * sizeof(int32_t));
  96. ge::AttrUtils::SetTensor(another_in_op_def, ge::ATTR_NAME_WEIGHTS, constTensor_another);
  97. op_def->AddInputDesc(another_tensor_desc);
  98. another_in_op_def->AddOutputDesc(another_tensor_desc);
  99. out_op_def->AddInputDesc(another_tensor_desc);
  100. }
  101. GeTensorDesc tensor_desc_out(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
  102. op_def->AddOutputDesc(tensor_desc_out);
  103. in_op_def->AddOutputDesc(tensor_desc);
  104. // add attr of out_node
  105. vector<bool> is_output_const(3, false);
  106. is_output_const[0] = true;
  107. out_op_def->SetIsInputConst(is_output_const);
  108. out_op_def->AddInputDesc(tensor_desc);
  109. // Add node
  110. NodePtr in_node = graph->AddNode(in_op_def);
  111. NodePtr node = graph->AddNode(op_def);
  112. NodePtr out_node = graph->AddNode(out_op_def);
  113. // Add edge
  114. GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), node->GetInDataAnchor(0));
  115. GraphUtils::AddEdge(node->GetOutDataAnchor(0), out_node->GetInDataAnchor(0));
  116. // when need multi input nodes (which to verify the isolate node function)
  117. if (isMultiInput) {
  118. NodePtr another_in_node = graph->AddNode(another_in_op_def);
  119. GraphUtils::AddEdge(another_in_node->GetOutDataAnchor(0), node->GetInDataAnchor(1));
  120. }
  121. return node;
  122. }
  123. };
  124. TEST_F(UTEST_graph_passes_guarantee_const_pass, not_changed) {
  125. // the original type of op is not guarantee_const
  126. string type = SIZE;
  127. // input tensor
  128. vector<int64_t> dims_vec = {6};
  129. vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  130. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  131. NodePtr node = init_node(graph, dims_vec, data_vec, false, type);
  132. ge::Status ret = guaranteeConstOpRemovePass->Run(node);
  133. EXPECT_EQ(domi::SUCCESS, ret);
  134. }
  135. TEST_F(UTEST_graph_passes_guarantee_const_pass, get_origenal_type_fail) {
  136. string type = GUARANTEECONST;
  137. // input tensor
  138. vector<int64_t> dims_vec = {6};
  139. vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  140. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  141. NodePtr node = init_node(graph, dims_vec, data_vec, false, type);
  142. // change the type
  143. string type2 = "FrameworkOp";
  144. node->GetOpDesc()->SetType(type2);
  145. ge::Status ret = guaranteeConstOpRemovePass->Run(node);
  146. // EXPECT_EQ(ge::SUCCESS, ret);
  147. }
  148. TEST_F(UTEST_graph_passes_guarantee_const_pass, Int32Success_6) {
  149. // input tensor
  150. string type = GUARANTEECONST;
  151. vector<int64_t> dims_vec = {6};
  152. vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  153. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  154. NodePtr node = init_node(graph, dims_vec, data_vec, false, type);
  155. // when input tensor is [1, 2, 3, 4, 5, 6], return success
  156. ge::Status output = guaranteeConstOpRemovePass->Run(node);
  157. EXPECT_EQ(ge::SUCCESS, output);
  158. }
  159. TEST_F(UTEST_graph_passes_guarantee_const_pass, Int32Success_2_3) {
  160. // input tensor
  161. string type = GUARANTEECONST;
  162. vector<int64_t> dims_vec = {2, 3};
  163. vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  164. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  165. NodePtr node = init_node(graph, dims_vec, data_vec, false, type);
  166. // when input tensor is [[1, 2, 3], [4, 5, 6]], return success
  167. ge::Status output = guaranteeConstOpRemovePass->Run(node);
  168. EXPECT_EQ(ge::SUCCESS, output);
  169. }
  170. TEST_F(UTEST_graph_passes_guarantee_const_pass, IsolateNodeFailed) {
  171. // input tensor
  172. string type = GUARANTEECONST;
  173. vector<int64_t> dims_vec = {2, 3};
  174. vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  175. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  176. // add another input node
  177. NodePtr node = init_node(graph, dims_vec, data_vec, true, type);
  178. // when there are more than one input anchors, return failed
  179. ge::Status output = guaranteeConstOpRemovePass->Run(node);
  180. EXPECT_EQ(ge::PARAM_INVALID, output);
  181. }
  182. // IR test, the shape and data type of input should be equal to the shape and data type of output
  183. TEST_F(UTEST_graph_passes_guarantee_const_pass, ir_infer_shape) {
  184. auto input = unordered_map<string, vector<int64_t>>({
  185. {"x", {3, 5, 3, 4}},
  186. });
  187. auto output = unordered_map<string, vector<int64_t>>({
  188. {"y", {3, 5, 3, 4}},
  189. });
  190. auto guaranteeConst = op::GuaranteeConst("guaranteeconst");
  191. TEST_OPERATOR(guaranteeConst, input, output);
  192. }

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示

Contributors (1)