You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_op_pass_unittest.cc 48 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <memory>
  18. #include <mutex>
  19. #include <thread>
  20. #include <vector>
  21. #include "common/types.h"
  22. #define protected public
  23. #define private public
  24. #include "graph/passes/variable_op_pass.h"
  25. #include "common/op/ge_op_utils.h"
  26. #include "graph/utils/op_desc_utils.h"
  27. #include "graph/utils/attr_utils.h"
  28. #include "graph/utils/graph_utils.h"
  29. #include "graph/op_desc.h"
  30. #include "graph/types.h"
  31. #include "graph/manager/graph_context.h"
  32. #include "graph/optimize/graph_optimize.h"
  33. #include "graph/manager/util/variable_accelerate_ctrl.h"
  34. #include "graph/manager/graph_mem_allocator.h"
  35. #include "graph/manager/graph_var_manager.h"
  36. #include "graph_builder_utils.h"
  37. #include "cce/dnn_struct_base.hpp"
  38. #include "common/formats/format_transfers/format_transfer.h"
  39. #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h"
  40. #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h"
  41. #include "common/formats/format_transfers/datatype_transfer.h"
  42. #undef private
  43. #undef protected
  44. using namespace std;
  45. using namespace ge;
  46. using namespace cce;
  47. class UtestVariableOpPassUnit : public testing::Test {
  48. protected:
  49. void SetUp() {}
  50. void TearDown() {}
  51. // AUTO GEN PLEASE DO NOT MODIFY IT
  52. };
  53. namespace {
  54. /// c
  55. /// var1ref1 --> netoutput1
  56. /// \ /
  57. /// transdata2
  58. /// |
  59. /// assign1
  60. /// / \
  61. /// transdata1 |
  62. /// | |
  63. /// var1 const1
  64. ComputeGraphPtr BuildGraph1() {
  65. auto builder = ut::GraphBuilder("g1");
  66. auto var1 = builder.AddNode("var1", "Variable", 0, 1);
  67. auto const1 =
  68. builder.AddNode("const1", "Const", 0, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16}));
  69. auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT,
  70. std::vector<int64_t>({1, 1, 224, 224, 16}));
  71. transdata1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NCHW);
  72. transdata1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 3, 224, 224})));
  73. auto assign1 =
  74. builder.AddNode("assign1", "Assign", 2, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16}));
  75. auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT,
  76. std::vector<int64_t>({1, 1, 224, 224, 16}));
  77. transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NCHW);
  78. transdata2->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 3, 224, 224})));
  79. auto var1ref1 = builder.AddNode("var1ref1", "Variable", 1, 0);
  80. AttrUtils::SetStr(var1ref1->GetOpDesc(), REF_VAR_SRC_VAR_NAME, "var1");
  81. auto netoutput1 = builder.AddNode("netoutput1", "Netoutput", 2, 0);
  82. builder.AddDataEdge(var1, 0, transdata1, 0);
  83. builder.AddDataEdge(const1, 0, assign1, 1);
  84. builder.AddDataEdge(transdata1, 0, assign1, 0);
  85. builder.AddDataEdge(assign1, 0, transdata2, 0);
  86. builder.AddDataEdge(transdata2, 0, var1ref1, 0);
  87. builder.AddDataEdge(transdata2, 0, netoutput1, 0);
  88. builder.AddControlEdge(var1ref1, netoutput1);
  89. return builder.GetGraph();
  90. }
  91. /// conv1
  92. /// |
  93. /// reshape1
  94. /// |
  95. /// var1
  96. ComputeGraphPtr BuildGraph2() {
  97. auto builder = ut::GraphBuilder("g1");
  98. auto var1 = builder.AddNode("var1", "Variable", 0, 1, FORMAT_ND, DT_FLOAT, std::vector<int64_t>({8 * 8 * 3, 2}));
  99. auto reshape1 =
  100. builder.AddNode("reshape1", "Reshape", 2, 1, FORMAT_HWCN, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  101. reshape1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_ND);
  102. reshape1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({8 * 8 * 3, 2})));
  103. auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1, FORMAT_HWCN, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  104. builder.AddDataEdge(var1, 0, reshape1, 0);
  105. builder.AddDataEdge(reshape1, 0, conv1, 1);
  106. return builder.GetGraph();
  107. }
  108. /// conv1
  109. /// |
  110. /// reformat1
  111. /// |
  112. /// var1
  113. ComputeGraphPtr BuildGraph3() {
  114. auto builder = ut::GraphBuilder("g1");
  115. auto var1 = builder.AddNode("var1", "Variable", 0, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  116. auto reformat1 =
  117. builder.AddNode("reformat1", "ReFormat", 1, 1, FORMAT_ND, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  118. reformat1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NCHW);
  119. reformat1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({8, 8, 3, 2})));
  120. auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1, FORMAT_ND, DT_FLOAT, std::vector<int64_t>({8, 8, 3, 2}));
  121. builder.AddDataEdge(var1, 0, reformat1, 0);
  122. builder.AddDataEdge(reformat1, 0, conv1, 1);
  123. return builder.GetGraph();
  124. }
  125. class NodeBuilder {
  126. public:
  127. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  128. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  129. ge::DataType data_type = DT_FLOAT) {
  130. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  131. return *this;
  132. }
  133. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  134. ge::DataType data_type = DT_FLOAT) {
  135. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  136. return *this;
  137. }
  138. ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
  139. private:
  140. ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  141. ge::DataType data_type = DT_FLOAT) {
  142. GeShape ge_shape{std::vector<int64_t>(shape)};
  143. ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
  144. tensor_desc->SetShape(ge_shape);
  145. tensor_desc->SetFormat(format);
  146. tensor_desc->SetDataType(data_type);
  147. return tensor_desc;
  148. }
  149. ge::OpDescPtr op_desc_;
  150. };
  151. std::string var_ref_name_0;
  152. ge::NodePtr CreatVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node) {
  153. GELOGI("Create VarRef Op: final_writable_node: [%s] var_node: [%s]>>>>", final_writable_node->GetName().c_str(),
  154. var_node->GetName().c_str());
  155. static uint32_t var_ref_count = 0;
  156. std::stringstream var_ref_name;
  157. var_ref_name << "_to_" << final_writable_node->GetName() << "_REF_" << var_ref_count++;
  158. OpDescPtr var_op_desc = var_node->GetOpDesc();
  159. GE_CHK_BOOL_EXEC(var_op_desc != nullptr, return nullptr, "get var opdesc is nullptr");
  160. OpDescPtr var_ref_op_desc = nullptr;
  161. GE_MAKE_SHARED(var_ref_op_desc =
  162. std::make_shared<OpDesc>(var_node->GetName() + var_ref_name.str().c_str(), var_op_desc->GetType()),
  163. return nullptr);
  164. var_ref_op_desc->AddOutputDesc(var_op_desc->GetOutputDesc(0));
  165. var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0));
  166. const map<string, ge::GeAttrValue> var_attr_value = var_op_desc->GetAllAttrs();
  167. for (auto const &attrIt : var_attr_value) {
  168. var_ref_op_desc->SetAttr(attrIt.first, attrIt.second);
  169. }
  170. NodePtr var_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc);
  171. GE_CHK_BOOL_EXEC(var_ref_node != nullptr, return nullptr, "create var_REF_node failed")
  172. GE_IF_BOOL_EXEC(ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName()),
  173. GELOGI("Set node [%s] VAR_ATTR_VAR_IS_REF [%s]", var_ref_node->GetName().c_str(),
  174. var_op_desc->GetName().c_str()));
  175. var_ref_name_0 = var_ref_node->GetName();
  176. return var_ref_node;
  177. }
  178. bool BuildComputeGraph0(ge::ComputeGraphPtr &graph) {
  179. // graph = std::make_shared<ComputeGraph>("test");
  180. ge::NodePtr node_4d_new =
  181. NodeBuilder("Node4D_new", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  182. ge::NodePtr node_4d_to_5d_1_new = NodeBuilder("4d_to_5d_1_new", TRANSDATA)
  183. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  184. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  185. .Build(graph);
  186. ge::NodePtr node_4d_to_5d_2_new = NodeBuilder("4d_to_5d_2_new", TRANSDATA)
  187. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  188. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  189. .Build(graph);
  190. ge::GraphUtils::AddEdge(node_4d_new->GetOutDataAnchor(0), node_4d_to_5d_1_new->GetInDataAnchor(0));
  191. ge::GraphUtils::AddEdge(node_4d_new->GetOutDataAnchor(0), node_4d_to_5d_2_new->GetInDataAnchor(0));
  192. // Node4D
  193. ge::NodePtr node_4d =
  194. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  195. // NodeTrans4DTo5D
  196. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  197. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  198. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  199. .Build(graph);
  200. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  201. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  202. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  203. .Build(graph);
  204. // Node5D
  205. ge::NodePtr node_5d_1 =
  206. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  207. ge::NodePtr node_5d_2 =
  208. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  209. // add edge
  210. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  211. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  212. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  213. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  214. // Node4D
  215. ge::NodePtr node_4d_nhwc =
  216. NodeBuilder("Node4D_NHWC", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NHWC, DT_INT32).Build(graph);
  217. // NodeTrans4DTo5D
  218. ge::NodePtr node_4d_to_5d_1_nhwc = NodeBuilder("4d_to_5d_1_NHWC", TRANSDATA)
  219. .AddInputDesc({1, 2, 3, 4}, FORMAT_NHWC, DT_INT32)
  220. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  221. .Build(graph);
  222. // Node5D
  223. ge::NodePtr node_5d_1_nhwc =
  224. NodeBuilder("5D_1_NHWC", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  225. // add edge
  226. ge::GraphUtils::AddEdge(node_4d_nhwc->GetOutDataAnchor(0), node_4d_to_5d_1_nhwc->GetInDataAnchor(0));
  227. ge::GraphUtils::AddEdge(node_4d_to_5d_1_nhwc->GetOutDataAnchor(0), node_5d_1_nhwc->GetInDataAnchor(0));
  228. // Node4D
  229. ge::NodePtr node_4d_hwcn =
  230. NodeBuilder("Node4D_HWCN", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_HWCN, DT_INT32).Build(graph);
  231. // NodeTrans4DTo5D
  232. ge::NodePtr node_4d_to_5d_1_hwcn = NodeBuilder("4d_to_5d_1_HWCN", TRANSDATA)
  233. .AddInputDesc({1, 2, 3, 4}, FORMAT_HWCN, DT_INT32)
  234. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  235. .Build(graph);
  236. // Node5D
  237. ge::NodePtr node_5d_1_hwcn =
  238. NodeBuilder("5D_1_HWCN", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  239. // add edge
  240. ge::GraphUtils::AddEdge(node_4d_hwcn->GetOutDataAnchor(0), node_4d_to_5d_1_hwcn->GetInDataAnchor(0));
  241. ge::GraphUtils::AddEdge(node_4d_to_5d_1_hwcn->GetOutDataAnchor(0), node_5d_1_hwcn->GetInDataAnchor(0));
  242. ge::NodePtr node_4d_chwn =
  243. NodeBuilder("Node4D_CHWN", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_CHWN, DT_INT32).Build(graph);
  244. // NodeTrans4DTo5D
  245. ge::NodePtr node_4d_to_5d_1_chwn = NodeBuilder("4d_to_5d_1_CHWN", TRANSDATA)
  246. .AddInputDesc({1, 2, 3, 4}, FORMAT_CHWN, DT_INT32)
  247. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  248. .Build(graph);
  249. // Node5D
  250. ge::NodePtr node_5d_1_chwn =
  251. NodeBuilder("5D_1_CHWN", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  252. // add edge
  253. ge::GraphUtils::AddEdge(node_4d_chwn->GetOutDataAnchor(0), node_4d_to_5d_1_chwn->GetInDataAnchor(0));
  254. ge::GraphUtils::AddEdge(node_4d_to_5d_1_chwn->GetOutDataAnchor(0), node_5d_1_chwn->GetInDataAnchor(0));
  255. ge::NodePtr node_4d_d =
  256. NodeBuilder("Node4D_D", VARIABLE).AddOutputDesc({1}, FORMAT_CHWN, DT_INT32).Build(graph);
  257. // NodeTrans4DTo5D
  258. ge::NodePtr node_4d_to_5d_1_d = NodeBuilder("4d_to_5d_1_D", TRANSDATA)
  259. .AddInputDesc({1, 2, 3, 4}, FORMAT_CHWN, DT_INT32)
  260. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  261. .Build(graph);
  262. // Node5D
  263. ge::NodePtr node_5d_1_d =
  264. NodeBuilder("5D_1_D", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  265. ge::NodePtr node_apply_monetum = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  266. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  267. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  268. .Build(graph);
  269. ge::NodePtr node_5d_to_4d_1 = NodeBuilder("5d_to_4d_1", TRANSDATA)
  270. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  271. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  272. .Build(graph);
  273. ge::NodePtr node_ref = CreatVariableRef(node_5d_to_4d_1, node_4d);
  274. // add edge
  275. ge::GraphUtils::AddEdge(node_4d_d->GetOutDataAnchor(0), node_4d_to_5d_1_d->GetInDataAnchor(0));
  276. ge::GraphUtils::AddEdge(node_4d_to_5d_1_d->GetOutDataAnchor(0), node_5d_1_d->GetInDataAnchor(0));
  277. if (ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0)) !=
  278. ge::SUCCESS) {
  279. /// GELOGE(FAILED, "ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0),
  280. /// node_5d_to_4d_1->GetInDataAnchor(0) ) Failed.");
  281. };
  282. ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  283. return true;
  284. }
  285. bool BuildComputeGraph1(ge::ComputeGraphPtr &graph) {
  286. // Node4D
  287. ge::NodePtr node_4d =
  288. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  289. // NodeTrans4DTo5D
  290. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  291. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  292. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  293. .Build(graph);
  294. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  295. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  296. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  297. .Build(graph);
  298. // Node5D
  299. ge::NodePtr node_5d_1 =
  300. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  301. ge::NodePtr node_5d_2 =
  302. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  303. ge::NodePtr node_5d_to_4d_1 = NodeBuilder("5d_to_4d_1", TRANSDATA)
  304. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  305. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  306. .Build(graph);
  307. ge::NodePtr node_apply_monetum = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  308. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  309. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  310. .Build(graph);
  311. ge::NodePtr node_ref = CreatVariableRef(node_5d_to_4d_1, node_4d);
  312. // add edge
  313. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  314. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  315. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  316. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  317. if (ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0)) !=
  318. ge::SUCCESS) {
  319. /// GELOGE(FAILED, "ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0),
  320. /// node_5d_to_4d_1->GetInDataAnchor(0) ) Failed.");
  321. };
  322. ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  323. return true;
  324. }
  325. bool BuildComputeGraph4(ge::ComputeGraphPtr &graph) {
  326. // Node4D
  327. ge::NodePtr node_4d =
  328. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  329. // NodeTrans4DTo5D
  330. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  331. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  332. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  333. .Build(graph);
  334. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  335. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  336. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  337. .Build(graph);
  338. // Node5D
  339. ge::NodePtr node_5d_1 =
  340. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  341. ge::NodePtr node_5d_2 =
  342. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT).Build(graph);
  343. ge::NodePtr node_5d_to_4d_1 = NodeBuilder("5d_to_4d_1", TRANSDATA)
  344. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  345. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  346. .Build(graph);
  347. ge::NodePtr node_5d_to_4d_2 = NodeBuilder("5d_to_4d_2", TRANSDATA)
  348. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  349. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  350. .Build(graph);
  351. ge::NodePtr node_apply_monetum = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  352. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  353. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  354. .Build(graph);
  355. ge::NodePtr node_ref = CreatVariableRef(node_5d_to_4d_1, node_4d);
  356. // add edge
  357. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  358. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  359. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  360. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  361. ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0));
  362. ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  363. ge::GraphUtils::AddEdge(node_5d_to_4d_2->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0));
  364. return true;
  365. }
  366. bool BuildComputeGraph5(ge::ComputeGraphPtr &graph) {
  367. // Node4D
  368. ge::NodePtr node_4d =
  369. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  370. return true;
  371. }
  372. bool BuildComputeGraph6(ge::ComputeGraphPtr &graph) {
  373. // Node4D
  374. ge::NodePtr node_4d =
  375. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  376. // NodeTrans4DTo5D
  377. ge::NodePtr node_4d_to_5d_1 = NodeBuilder("4d_to_5d_1", TRANSDATA)
  378. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  379. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  380. .Build(graph);
  381. ge::NodePtr node_float_to_int_1 = NodeBuilder("float_to_int_1", CAST)
  382. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  383. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  384. .Build(graph);
  385. ge::NodePtr node_4d_to_5d_2 = NodeBuilder("4d_to_5d_2", TRANSDATA)
  386. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  387. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  388. .Build(graph);
  389. ge::NodePtr node_float_to_int_2 = NodeBuilder("float_to_int_2", CAST)
  390. .AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_FLOAT)
  391. .AddOutputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32)
  392. .Build(graph);
  393. // Node5D
  394. ge::NodePtr node_5d_1 =
  395. NodeBuilder("5D_1", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32).Build(graph);
  396. ge::NodePtr node_5d_2 =
  397. NodeBuilder("5D_2", RELU).AddInputDesc({1, 2, 3, 4, 5}, FORMAT_NC1HWC0, DT_INT32).Build(graph);
  398. // add edge
  399. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_1->GetInDataAnchor(0));
  400. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_5d_2->GetInDataAnchor(0));
  401. ge::GraphUtils::AddEdge(node_4d_to_5d_1->GetOutDataAnchor(0), node_float_to_int_1->GetInDataAnchor(0));
  402. ge::GraphUtils::AddEdge(node_4d_to_5d_2->GetOutDataAnchor(0), node_float_to_int_2->GetInDataAnchor(0));
  403. ge::GraphUtils::AddEdge(node_float_to_int_1->GetOutDataAnchor(0), node_5d_1->GetInDataAnchor(0));
  404. ge::GraphUtils::AddEdge(node_float_to_int_2->GetOutDataAnchor(0), node_5d_2->GetInDataAnchor(0));
  405. return true;
  406. }
  407. } // namespace
  408. bool BuildComputeGraph7(ge::ComputeGraphPtr &graph) {
  409. // Node4D
  410. ge::NodePtr node_4d =
  411. NodeBuilder("Node4D", VARIABLE).AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32).Build(graph);
  412. // NodeTrans4DTo5D
  413. ge::NodePtr node_4d_to_4d_1 = NodeBuilder("4d_to_4d_1", TRANSDATA)
  414. .AddInputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  415. .AddOutputDesc({1, 2, 3, 4}, FORMAT_NCHW, DT_INT32)
  416. .Build(graph);
  417. // Node5D
  418. ge::NodePtr node_4d_1 = NodeBuilder("4D_1", RELU).AddInputDesc({1, 2, 3, 4}, FORMAT_NC1HWC0, DT_INT32).Build(graph);
  419. // add edge
  420. ge::GraphUtils::AddEdge(node_4d->GetOutDataAnchor(0), node_4d_to_4d_1->GetInDataAnchor(0));
  421. ge::GraphUtils::AddEdge(node_4d_to_4d_1->GetOutDataAnchor(0), node_4d_1->GetInDataAnchor(0));
  422. return true;
  423. }
  424. class VariableOpPassSimulator {
  425. public:
  426. bool DoTest0() {
  427. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  428. const std::string var_name = "Node4D";
  429. uint64_t session_id = 0;
  430. uint32_t device_id = 0;
  431. uint64_t job_id = 0;
  432. uint32_t session_version = 0;
  433. std::vector<int64_t> dims(4, 20);
  434. ge::GeShape shape(dims);
  435. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  436. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  437. BuildComputeGraph0(compute_graph);
  438. std::vector<std::string> var_names = {"Node4D_new", "Node4D", "Node4D_NHWC",
  439. "Node4D_HWCN", "Node4D_CHWN", "Node4D_D"};
  440. for (auto name : var_names) {
  441. auto var_node = compute_graph->FindNode(name);
  442. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  443. uint8_t *dev_ptr = nullptr;
  444. ge::VarManager::Instance(session_id)->AssignVarMem(name, var_tensor_desc, RT_MEMORY_HBM);
  445. ge::VarManager::Instance(session_id)->SetVarAddr(name, var_tensor_desc, dev_ptr, RT_MEMORY_HBM);
  446. }
  447. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  448. compute_graph->InferShapeInNeed();
  449. graph_node->SetComputeGraph(compute_graph);
  450. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  451. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  452. graph_node->SetGraph(tmp_graph_ptr);
  453. VarAccelerateCtrl ctrl;
  454. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  455. ge::formats::FormatTransferNchwNc1hwc0 ClassObj;
  456. VariableOpPass pass(&ctrl);
  457. pass.Run(compute_graph);
  458. MemManager::Instance().Finalize();
  459. return CheckTest0(compute_graph);
  460. }
  461. bool DoTest1() {
  462. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  463. const std::string var_name = "Node4D";
  464. uint64_t session_id = 0;
  465. uint32_t device_id = 0;
  466. uint64_t job_id = 0;
  467. uint32_t session_version = 0;
  468. std::vector<int64_t> dims(4, 20);
  469. ge::GeShape shape(dims);
  470. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  471. BuildComputeGraph1(compute_graph);
  472. auto var_node = compute_graph->FindNode(var_name);
  473. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  474. uint8_t *dev_ptr = nullptr;
  475. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  476. compute_graph->InferShapeInNeed();
  477. graph_node->SetComputeGraph(compute_graph);
  478. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  479. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  480. graph_node->SetGraph(tmp_graph_ptr);
  481. VarAccelerateCtrl ctrl;
  482. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  483. VariableOpPass pass(&ctrl);
  484. pass.Run(compute_graph);
  485. return CheckTest1(compute_graph);
  486. }
  487. bool DoTest2() {
  488. VarAccelerateCtrl ctrl;
  489. VariableOpPass pass(&ctrl);
  490. return pass.Run(nullptr) == ge::INTERNAL_ERROR;
  491. }
  492. bool DoTest3() {
  493. std::vector<rtMemType_t> mem_type;
  494. std::map<std::string, std::string> empty_options;
  495. mem_type.push_back(RT_MEMORY_HBM);
  496. MemManager::Instance().Initialize(mem_type);
  497. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  498. std::vector<std::string> var_names = {"Node4D", "Node4D_NHWC", "Node4D_HWCN", "Node4D_CHWN", "Node4D_D"};
  499. std::vector<ge::GeTensorDesc> tensor_descs;
  500. uint64_t session_id = 0;
  501. uint32_t device_id = 0;
  502. uint64_t job_id = 0;
  503. uint32_t session_version = 0;
  504. compute_graph->SetSessionID(session_id);
  505. std::vector<int64_t> dims(4, 20);
  506. ge::GeShape shape(dims);
  507. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  508. BuildComputeGraph0(compute_graph);
  509. for (auto var_name : var_names) {
  510. auto var_node = compute_graph->FindNode(var_name);
  511. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  512. uint8_t *dev_ptr = nullptr;
  513. ge::VarManager::Instance(session_id)->AssignVarMem(var_name, var_tensor_desc, RT_MEMORY_HBM);
  514. ge::VarManager::Instance(session_id)->SetVarAddr(var_name, var_tensor_desc, dev_ptr, RT_MEMORY_HBM);
  515. }
  516. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  517. compute_graph->InferShapeInNeed();
  518. graph_node->SetComputeGraph(compute_graph);
  519. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  520. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  521. graph_node->SetGraph(tmp_graph_ptr);
  522. VarAccelerateCtrl ctrl;
  523. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  524. VariableOpPass pass(&ctrl);
  525. auto ret = pass.Run(compute_graph);
  526. MemManager::Instance().Finalize();
  527. return ret == GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  528. }
  529. bool DoTest4() {
  530. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  531. const std::string var_name = "Node4D";
  532. uint64_t session_id = 0;
  533. uint32_t device_id = 0;
  534. uint64_t job_id = 0;
  535. uint32_t session_version = 0;
  536. std::vector<int64_t> dims(4, 20);
  537. ge::GeShape shape(dims);
  538. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  539. BuildComputeGraph4(compute_graph);
  540. auto var_node = compute_graph->FindNode(var_name);
  541. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  542. uint8_t *dev_ptr = nullptr;
  543. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  544. compute_graph->InferShapeInNeed();
  545. graph_node->SetComputeGraph(compute_graph);
  546. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  547. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  548. graph_node->SetGraph(tmp_graph_ptr);
  549. VarAccelerateCtrl ctrl;
  550. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  551. VariableOpPass pass(&ctrl);
  552. auto ret = pass.Run(compute_graph);
  553. return ret == ge::SUCCESS;
  554. }
  555. bool DoTest5() {
  556. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  557. BuildComputeGraph5(compute_graph);
  558. const std::string var_name = "Node4D";
  559. uint64_t session_id = 0;
  560. uint32_t device_id = 0;
  561. uint64_t job_id = 0;
  562. uint32_t session_version = 0;
  563. std::vector<int64_t> dims(4, 20);
  564. ge::GeShape shape(dims);
  565. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  566. BuildComputeGraph4(compute_graph);
  567. auto var_node = compute_graph->FindNode(var_name);
  568. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  569. uint8_t *dev_ptr = nullptr;
  570. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  571. compute_graph->InferShapeInNeed();
  572. graph_node->SetComputeGraph(compute_graph);
  573. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  574. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  575. graph_node->SetGraph(tmp_graph_ptr);
  576. VarAccelerateCtrl ctrl;
  577. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  578. VariableOpPass pass(&ctrl);
  579. auto ret = pass.Run(compute_graph);
  580. return ret == ge::SUCCESS;
  581. }
  582. bool DoTest6() {
  583. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  584. const std::string var_name = "Node4D";
  585. uint64_t session_id = 0;
  586. uint32_t device_id = 0;
  587. uint64_t job_id = 0;
  588. uint32_t session_version = 0;
  589. std::vector<int64_t> dims(4, 20);
  590. ge::GeShape shape(dims);
  591. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  592. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  593. BuildComputeGraph6(compute_graph);
  594. auto var_node = compute_graph->FindNode(var_name);
  595. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  596. uint8_t *dev_ptr = nullptr;
  597. ge::VarManager::Instance(session_id)->AssignVarMem(var_name, var_tensor_desc, RT_MEMORY_HBM);
  598. ge::VarManager::Instance(session_id)->SetVarAddr(var_name, var_tensor_desc, dev_ptr, RT_MEMORY_HBM);
  599. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  600. compute_graph->InferShapeInNeed();
  601. graph_node->SetComputeGraph(compute_graph);
  602. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  603. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  604. graph_node->SetGraph(tmp_graph_ptr);
  605. VarAccelerateCtrl ctrl;
  606. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  607. ge::formats::FormatTransferNchwNc1hwc0 ClassObj;
  608. VariableOpPass pass(&ctrl);
  609. auto ret = pass.Run(compute_graph);
  610. MemManager::Instance().Finalize();
  611. return CheckTest6(compute_graph);
  612. }
  613. bool DoTest7() {
  614. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  615. const std::string var_name = "Node4D";
  616. uint64_t session_id = 0;
  617. uint32_t device_id = 0;
  618. uint64_t job_id = 0;
  619. uint32_t session_version = 0;
  620. std::vector<int64_t> dims(4, 20);
  621. ge::GeShape shape(dims);
  622. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  623. BuildComputeGraph7(compute_graph);
  624. auto var_node = compute_graph->FindNode(var_name);
  625. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  626. uint8_t *dev_ptr = nullptr;
  627. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  628. compute_graph->InferShapeInNeed();
  629. graph_node->SetComputeGraph(compute_graph);
  630. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  631. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  632. graph_node->SetGraph(tmp_graph_ptr);
  633. VarAccelerateCtrl ctrl;
  634. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  635. VariableOpPass pass(&ctrl);
  636. auto ret = pass.Run(compute_graph);
  637. return CheckTest7(compute_graph);
  638. }
  639. bool DoTest8() {
  640. ge::ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("0");
  641. const std::string var_name = "Node4D";
  642. uint64_t session_id = 0;
  643. uint32_t device_id = 0;
  644. uint64_t job_id = 0;
  645. uint32_t session_version = 0;
  646. std::vector<int64_t> dims(4, 20);
  647. ge::GeShape shape(dims);
  648. VarManager::Instance(session_id)->Init(session_version, session_id, device_id, job_id);
  649. BuildComputeGraph0(compute_graph);
  650. auto var_node = compute_graph->FindNode(var_name);
  651. auto var_tensor_desc = var_node->GetOpDesc()->GetOutputDesc(0);
  652. uint8_t *dev_ptr = nullptr;
  653. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  654. compute_graph->InferShapeInNeed();
  655. graph_node->SetComputeGraph(compute_graph);
  656. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  657. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  658. graph_node->SetGraph(tmp_graph_ptr);
  659. VarAccelerateCtrl ctrl;
  660. ctrl.AddGraph(graph_node->GetGraphId(), compute_graph);
  661. VariableOpPass pass(&ctrl);
  662. pass.Run(compute_graph);
  663. return CheckTest8(compute_graph);
  664. }
  665. private:
  666. bool CheckTest0(const ge::ComputeGraphPtr compute_graph) {
  667. const auto &variable_node = compute_graph->FindNode("Node4D");
  668. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  669. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  670. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  671. if (variable_node_format != FORMAT_NC1HWC0 || variable_node_data_type != DT_FLOAT ||
  672. variable_node_shape.size() != 5) {
  673. std::cout << "var format not changed !" << std::endl;
  674. return false;
  675. }
  676. const auto &variable_ref_node = compute_graph->FindNode(var_ref_name_0);
  677. GELOGD("var_ref_name_0 is %s", var_ref_name_0.c_str());
  678. auto variable_ref_node_format = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetFormat();
  679. auto variable_ref_node_data_type = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetDataType();
  680. auto variable_ref_node_shape = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
  681. if (variable_ref_node_format != FORMAT_NC1HWC0 || variable_ref_node_data_type != DT_FLOAT ||
  682. variable_ref_node_shape.size() != 5) {
  683. GELOGI("wanted data format is (%d,%d,%u)", FORMAT_NC1HWC0, DT_FLOAT, 5);
  684. GELOGI("variable_ref_node_format is (%d,%d,%u)", variable_ref_node_format, variable_ref_node_data_type,
  685. variable_ref_node_shape.size());
  686. std::cout << "var ref format not changed !" << std::endl;
  687. return false;
  688. }
  689. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_5d_1");
  690. if (trans_node != nullptr) {
  691. std::cout << "4d_to_5d_1 not empty !" << std::endl;
  692. return false;
  693. }
  694. trans_node = compute_graph->FindNode("4d_to_5d_2");
  695. if (trans_node != nullptr) {
  696. std::cout << "4d_to_5d_2 not empty !" << std::endl;
  697. return false;
  698. }
  699. trans_node = compute_graph->FindNode("5d_to_4d_1");
  700. if (trans_node != nullptr) {
  701. std::cout << "5d_to_4d_1 not empty !" << std::endl;
  702. return false;
  703. }
  704. trans_node = compute_graph->FindNode("4d_to_5d_1_new");
  705. if (trans_node == nullptr) {
  706. std::cout << "4d_to_5d_1_new is empty !" << std::endl;
  707. return false;
  708. }
  709. auto new_variable_node = compute_graph->FindNode("Node4D_new");
  710. auto new_variable_node_format = new_variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  711. auto new_variable_node_data_type = new_variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  712. auto new_variable_node_shape = new_variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  713. if (new_variable_node_format != FORMAT_NCHW || new_variable_node_data_type != DT_INT32 ||
  714. new_variable_node_shape.size() != 4) {
  715. std::cout << "Node4D_new format Changed ! wanted data format is ( " << FORMAT_NC1HWC0 << ", " << DT_INT32
  716. << ", 4) " << std::endl;
  717. std::cout << "current is ( " << new_variable_node_format << ", " << new_variable_node_data_type << ", "
  718. << new_variable_node_shape.size() << ")" << std::endl;
  719. return false;
  720. }
  721. return true;
  722. };
  723. bool CheckTest1(const ge::ComputeGraphPtr compute_graph) {
  724. const auto &variable_node = compute_graph->FindNode("Node4D");
  725. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  726. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  727. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  728. if (variable_node_format != FORMAT_NCHW || variable_node_data_type != DT_INT32 || variable_node_shape.size() != 4) {
  729. std::cout << "var format changed !" << std::endl;
  730. return false;
  731. }
  732. const auto &variable_ref_node = compute_graph->FindNode(var_ref_name_0);
  733. GELOGD("var_ref_name_0 is %s", var_ref_name_0.c_str());
  734. auto variable_ref_node_format = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetFormat();
  735. auto variable_ref_node_data_type = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetDataType();
  736. auto variable_ref_node_shape = variable_ref_node->GetOpDesc()->GetInputDesc(0).GetShape().GetDims();
  737. if (variable_ref_node_format != FORMAT_NCHW || variable_ref_node_data_type != DT_INT32 ||
  738. variable_ref_node_shape.size() != 4) {
  739. GELOGI("wanted data format is (%d,%d,%u)", FORMAT_NCHW, DT_INT32, 4);
  740. GELOGI("variable_ref_node_format is (%d,%d,%u)", variable_ref_node_format, variable_ref_node_data_type,
  741. variable_ref_node_shape.size());
  742. std::cout << "var ref format not changed !" << std::endl;
  743. return false;
  744. }
  745. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_5d_1");
  746. if (trans_node == nullptr) {
  747. std::cout << "4d_to_5d_1 empty !" << std::endl;
  748. return false;
  749. }
  750. trans_node = compute_graph->FindNode("4d_to_5d_2");
  751. if (trans_node == nullptr) {
  752. std::cout << "4d_to_5d_2 empty !" << std::endl;
  753. return false;
  754. }
  755. trans_node = compute_graph->FindNode("5d_to_4d_1");
  756. if (trans_node == nullptr) {
  757. std::cout << "5d_to_4d_1 not empty !" << std::endl;
  758. return false;
  759. }
  760. return true;
  761. };
  762. bool CheckTest6(const ge::ComputeGraphPtr compute_graph) {
  763. const auto &variable_node = compute_graph->FindNode("Node4D");
  764. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  765. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  766. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  767. if (variable_node_format != FORMAT_NC1HWC0 || variable_node_data_type != DT_INT32 ||
  768. variable_node_shape.size() != 5) {
  769. std::cout << "var format not changed !" << std::endl;
  770. return false;
  771. }
  772. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_5d_1");
  773. if (trans_node != nullptr) {
  774. std::cout << "4d_to_5d_1 not empty !" << std::endl;
  775. return false;
  776. }
  777. trans_node = compute_graph->FindNode("4d_to_5d_2");
  778. if (trans_node != nullptr) {
  779. std::cout << "4d_to_5d_2 not empty !" << std::endl;
  780. return false;
  781. }
  782. trans_node = compute_graph->FindNode("float_to_int_1");
  783. if (trans_node != nullptr) {
  784. std::cout << "float_to_int_1 not empty !" << std::endl;
  785. return false;
  786. }
  787. trans_node = compute_graph->FindNode("float_to_int_2");
  788. if (trans_node != nullptr) {
  789. std::cout << "float_to_int_1 not empty !" << std::endl;
  790. return false;
  791. }
  792. return true;
  793. };
  794. bool CheckTest7(const ge::ComputeGraphPtr compute_graph) {
  795. const auto &variable_node = compute_graph->FindNode("Node4D");
  796. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  797. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  798. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  799. if (variable_node_format != FORMAT_NC1HWC0 || variable_node_data_type != DT_INT32 ||
  800. variable_node_shape.size() != 5) {
  801. std::cout << "var format not changed !" << std::endl;
  802. return false;
  803. }
  804. ge::NodePtr trans_node = compute_graph->FindNode("4d_to_4d_1");
  805. if (trans_node != nullptr) {
  806. std::cout << "4d_to_5d_1 not empty !" << std::endl;
  807. return false;
  808. }
  809. return true;
  810. };
  811. bool CheckTest8(const ge::ComputeGraphPtr compute_graph) {
  812. const auto &variable_node = compute_graph->FindNode("Node4D");
  813. auto variable_node_format = variable_node->GetOpDesc()->GetOutputDesc(0).GetFormat();
  814. auto variable_node_data_type = variable_node->GetOpDesc()->GetOutputDesc(0).GetDataType();
  815. auto variable_node_shape = variable_node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims();
  816. return true;
  817. };
  818. };
  819. TEST_F(UtestVariableOpPassUnit, test_trans_data_remove) {
  820. VariableOpPassSimulator varibale_op_pass_simulator;
  821. bool result = varibale_op_pass_simulator.DoTest0();
  822. EXPECT_EQ(result, true);
  823. }
  824. TEST_F(UtestVariableOpPassUnit, test_variable_ref) {
  825. VariableOpPassSimulator varibale_op_pass_simulator;
  826. bool result = varibale_op_pass_simulator.DoTest1();
  827. EXPECT_EQ(result, true);
  828. }
  829. TEST_F(UtestVariableOpPassUnit, test_null_graph) {
  830. VariableOpPassSimulator varibale_op_pass_simulator;
  831. bool result = varibale_op_pass_simulator.DoTest2();
  832. EXPECT_EQ(result, true);
  833. }
  834. TEST_F(UtestVariableOpPassUnit, test_covarage_trans_var_data) {
  835. VariableOpPassSimulator varibale_op_pass_simulator;
  836. bool result = varibale_op_pass_simulator.DoTest3();
  837. EXPECT_EQ(result, false);
  838. }
  839. TEST_F(UtestVariableOpPassUnit, test_illegally_ref) {
  840. VariableOpPassSimulator varibale_op_pass_simulator;
  841. bool result = varibale_op_pass_simulator.DoTest4();
  842. EXPECT_EQ(result, true);
  843. }
  844. TEST_F(UtestVariableOpPassUnit, test_single_node) {
  845. VariableOpPassSimulator varibale_op_pass_simulator;
  846. bool result = varibale_op_pass_simulator.DoTest5();
  847. EXPECT_EQ(result, true);
  848. }
  849. TEST_F(UtestVariableOpPassUnit, test_un_mathed) {
  850. VariableOpPassSimulator varibale_op_pass_simulator;
  851. bool result = varibale_op_pass_simulator.DoTest6();
  852. EXPECT_EQ(result, true);
  853. }
  854. TEST_F(UtestVariableOpPassUnit, test_same_op) {
  855. VariableOpPassSimulator varibale_op_pass_simulator;
  856. bool result = varibale_op_pass_simulator.DoTest7();
  857. EXPECT_EQ(true, true);
  858. }
  859. TEST_F(UtestVariableOpPassUnit, test_error_return) {
  860. VariableOpPassSimulator varibale_op_pass_simulator;
  861. bool result = varibale_op_pass_simulator.DoTest8();
  862. EXPECT_EQ(true, true);
  863. }
  864. TEST_F(UtestVariableOpPassUnit, reshape) {
  865. // init
  866. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  867. VarManager::Instance(0)->Init(0, 0, 0, 0);
  868. auto graph = BuildGraph2();
  869. graph->SetSessionID(0);
  870. auto var1 = graph->FindNode("var1");
  871. VarManager::Instance(0)->AssignVarMem(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), RT_MEMORY_HBM);
  872. uint8_t *dev_ptr = nullptr;
  873. VarManager::Instance(0)->SetVarAddr(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), dev_ptr, RT_MEMORY_HBM);
  874. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  875. graph->InferShapeInNeed();
  876. graph_node->SetComputeGraph(graph);
  877. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  878. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  879. graph_node->SetGraph(tmp_graph_ptr);
  880. VarAccelerateCtrl ctrl;
  881. ctrl.AddGraph(graph_node->GetGraphId(), graph);
  882. VariableOpPass pass(&ctrl);
  883. EXPECT_EQ(pass.Run(graph), ge::SUCCESS);
  884. MemManager::Instance().Finalize();
  885. EXPECT_EQ(var1->GetOutNodes().size(), 1);
  886. EXPECT_EQ(var1->GetOutDataNodes().at(0)->GetName(), "conv1");
  887. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetFormat(), FORMAT_HWCN);
  888. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims(), std::vector<int64_t>({8, 8, 3, 2}));
  889. }
  890. TEST_F(UtestVariableOpPassUnit, reformat) {
  891. // init
  892. MemManager::Instance().Initialize(std::vector<rtMemType_t>({RT_MEMORY_HBM}));
  893. VarManager::Instance(0)->Init(0, 0, 0, 0);
  894. auto graph = BuildGraph3();
  895. graph->SetSessionID(0);
  896. auto var1 = graph->FindNode("var1");
  897. VarManager::Instance(0)->AssignVarMem(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), RT_MEMORY_HBM);
  898. uint8_t *dev_ptr = nullptr;
  899. VarManager::Instance(0)->SetVarAddr(var1->GetName(), var1->GetOpDesc()->GetOutputDesc(0), dev_ptr, RT_MEMORY_HBM);
  900. ge::GraphNodePtr graph_node = make_shared<GraphNode>(0);
  901. graph->InferShapeInNeed();
  902. graph_node->SetComputeGraph(graph);
  903. auto tmp_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
  904. auto tmp_graph_ptr = std::make_shared<Graph>(tmp_graph);
  905. graph_node->SetGraph(tmp_graph_ptr);
  906. VarAccelerateCtrl ctrl;
  907. ctrl.AddGraph(graph_node->GetGraphId(), graph);
  908. VariableOpPass pass(&ctrl);
  909. EXPECT_EQ(pass.Run(graph), ge::SUCCESS);
  910. MemManager::Instance().Finalize();
  911. EXPECT_EQ(var1->GetOutNodes().size(), 1);
  912. EXPECT_EQ(var1->GetOutDataNodes().at(0)->GetName(), "conv1");
  913. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetFormat(), FORMAT_ND);
  914. EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims(), std::vector<int64_t>({8, 8, 3, 2}));
  915. }
  916. TEST_F(UtestVariableOpPassUnit, invalid_src_shape2) {
  917. formats::FormatTransferNchwNc1hwc0 t1;
  918. formats::FormatTransferNhwcNc1hwc0 t2;
  919. formats::TransArgs args = formats::TransArgs();
  920. formats::TransResult ret;
  921. t2.TransFormat(args, ret);
  922. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承上启下的作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练或推理时，GE会被自动调用，用户对此无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：