
ge_hybrid_unittest.cc

/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
#include "runtime/rt.h"
#define protected public
#define private public
#include "graph/utils/node_utils.h"
#include "hybrid/model/hybrid_model_builder.h"
#include "hybrid/model/hybrid_model.h"
#include "hybrid/node_executor/node_executor.h"
#include "common/model/ge_model.h"
#include "common/model/ge_root_model.h"
#include "hybrid/node_executor/aicore/aicore_op_task.h"
#include "framework/common/taskdown_common.h"
#include "framework/common/debug/log.h"
#include "graph/ge_context.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/node_executor/aicore/aicore_task_builder.h"
#include "hybrid/node_executor/aicore/aicore_node_executor.h"
#include "graph/load/model_manager/tbe_handle_store.h"
#include "graph/manager/graph_mem_allocator.h"
#include "hybrid/common/npu_memory_allocator.h"
#include "graph/types.h"
#include "graph/utils/tensor_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
#include "single_op/task/build_task_utils.h"
#include "graph/op_desc_impl.h"

using namespace std;

namespace ge {
using namespace hybrid;

class UtestGeHybrid : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {
    NpuMemoryAllocator::allocators_.clear();
  }
};

static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", int in_num = 0, int out_num = 0) {
  auto op_desc = std::make_shared<ge::OpDesc>(name, type);
  op_desc->SetStreamId(0);
  static int32_t index = 0;
  op_desc->SetId(index++);
  GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
  TensorUtils::SetSize(tensor, 64);
  vector<int64_t> input_offset;
  for (int i = 0; i < in_num; ++i) {
    op_desc->AddInputDesc(tensor);
    input_offset.emplace_back(index * 64 + i * 64);
  }
  op_desc->SetInputOffset(input_offset);
  vector<int64_t> output_offset;
  for (int i = 0; i < out_num; ++i) {
    op_desc->AddOutputDesc(tensor);
    output_offset.emplace_back(index * 64 + in_num * 64 + i * 64);
  }
  op_desc->SetOutputOffset(output_offset);
  op_desc->SetWorkspace({});
  op_desc->SetWorkspaceBytes({});
  ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC");
  bool support_dynamic = true;
  ge::AttrUtils::GetBool(op_desc, "support_dynamicshape", support_dynamic);
  return op_desc;
}

TEST_F(UtestGeHybrid, aicore_op_task_init_success) {
  // build aicore task
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  domi::TaskDef task_def;
  task_def.set_type(RT_MODEL_TASK_ALL_KERNEL);
  domi::KernelDefWithHandle *kernel_with_handle = task_def.mutable_kernel_with_handle();
  kernel_with_handle->set_original_kernel_key("");
  kernel_with_handle->set_node_info("");
  kernel_with_handle->set_block_dim(32);
  kernel_with_handle->set_args_size(64);
  string args(64, '1');
  kernel_with_handle->set_args(args.data(), 64);
  domi::KernelContext *context = kernel_with_handle->mutable_context();
  context->set_op_index(1);
  context->set_kernel_type(2);  // ccKernelType::TE
  uint16_t args_offset[9] = {0};
  context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  std::vector<char> kernelBin;
  TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
  op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
  std::string kernel_name("kernel/Add");
  AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
  ASSERT_EQ(aicore_task->Init(*op_desc.get(), task_def), SUCCESS);
  rtStream_t stream = nullptr;
  rtStreamCreate(&stream, 0);
  ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
  char *handle = "";
  aicore_task->handle_ = handle;
  aicore_task->tiling_key_ = 1;
  ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
}

TEST_F(UtestGeHybrid, aicore_op_task_init_success2) {
  // build aicore task
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  aicore_task->is_single_op_ = true;
  domi::TaskDef task_def;
  task_def.set_type(RT_MODEL_TASK_KERNEL);
  domi::KernelDef *kernel = task_def.mutable_kernel();
  kernel->set_block_dim(32);
  kernel->set_args_size(64);
  string args(64, '1');
  kernel->set_args(args.data(), 64);
  domi::KernelContext *context = kernel->mutable_context();
  context->set_op_index(1);
  context->set_kernel_type(2);  // ccKernelType::TE
  uint16_t args_offset[9] = {0};
  context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  std::vector<char> kernelBin;
  TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
  op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
  std::string kernel_name("kernel/Add");
  AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
  ASSERT_EQ(aicore_task->InitWithTaskDef(*op_desc.get(), task_def), SUCCESS);
  rtStream_t stream = nullptr;
  rtStreamCreate(&stream, 0);
  ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
  char *handle = "";
  aicore_task->handle_ = handle;
  aicore_task->tiling_key_ = 1;
  ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
}

TEST_F(UtestGeHybrid, task_update_tiling_info) {
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  auto graph = make_shared<ComputeGraph>("graph");
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  ge::AttrUtils::SetStr(op_desc, "compile_info_key", "key");
  ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json");
  ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true);
  ge::AttrUtils::SetInt(op_desc, "op_para_size", 1);
  auto node = graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  node_item->input_start = 0;
  node_item->output_start = 0;
  GraphExecutionContext execution_context;
  GraphItem graph_item;
  SubgraphContext subgraph_context(&graph_item, &execution_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
  ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS);
  ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_failed) {
  // build aicore task
  domi::ModelTaskDef model_task_def;
  std::shared_ptr<domi::ModelTaskDef> model_task_def_ptr = make_shared<domi::ModelTaskDef>(model_task_def);
  domi::TaskDef *task_def = model_task_def_ptr->add_task();
  GeModelPtr ge_model = make_shared<GeModel>();
  ge_model->SetModelTaskDef(model_task_def_ptr);
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  task_def->set_type(RT_MODEL_TASK_ALL_KERNEL);
  domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle();
  kernel_with_handle->set_original_kernel_key("");
  kernel_with_handle->set_node_info("");
  kernel_with_handle->set_block_dim(32);
  kernel_with_handle->set_args_size(64);
  string args(64, '1');
  kernel_with_handle->set_args(args.data(), 64);
  domi::KernelContext *context = kernel_with_handle->mutable_context();
  context->set_op_index(1);
  context->set_kernel_type(2);  // ccKernelType::TE
  uint16_t args_offset[9] = {0};
  context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  std::vector<char> kernelBin;
  TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
  op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
  std::string kernel_name("kernel/Add");
  AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  HybridModel hybrid_model(ge_root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);
  ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR);
  ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}

TEST_F(UtestGeHybrid, parse_force_infershape_nodes) {
  const char *const kForceInfershape = "_force_infershape_when_running";
  auto graph = make_shared<ComputeGraph>("graph");
  OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D");
  ge::AttrUtils::SetBool(op_desc, kForceInfershape, true);
  auto node = graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> new_node;
  NodeItem::Create(node, new_node);
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  HybridModel hybrid_model(ge_root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);
  ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS);
}

static ComputeGraphPtr BuildDataDirectConnectGraph() {
  const char *kRefIndex = "_parent_node_index";
  ge::ut::GraphBuilder builder("subgraph");
  auto data = builder.AddNode("Data", "Data", 1, 1);
  auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 1);
  (void)AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), kRefIndex, 0);
  builder.AddDataEdge(data, 0, netoutput, 0);
  return builder.GetGraph();
}

TEST_F(UtestGeHybrid, data_direct_connect) {
  std::unique_ptr<NodeItem> node_item;
  auto root_graph = make_shared<ComputeGraph>("root_graph");
  OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall");
  auto node = root_graph->AddNode(op_desc);
  node->SetOwnerComputeGraph(root_graph);
  auto sub_graph = BuildDataDirectConnectGraph();
  sub_graph->SetParentGraph(root_graph);
  sub_graph->SetParentNode(node);
  node->GetOpDesc()->AddSubgraphName("subgraph");
  node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph");
  root_graph->AddSubgraph("subgraph", sub_graph);
  std::unique_ptr<NodeItem> new_node;
  NodeItem::Create(node, new_node);
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph);
  HybridModel hybrid_model(ge_root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);
  auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get(), sub_graph);
  ASSERT_EQ(ret, SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_success) {
  // build aicore task
  domi::ModelTaskDef model_task_def;
  std::shared_ptr<domi::ModelTaskDef> model_task_def_ptr = make_shared<domi::ModelTaskDef>(model_task_def);
  domi::TaskDef *task_def = model_task_def_ptr->add_task();
  GeModelPtr ge_model = make_shared<GeModel>();
  ge_model->SetModelTaskDef(model_task_def_ptr);
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  task_def->set_type(RT_MODEL_TASK_ALL_KERNEL);
  domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle();
  kernel_with_handle->set_original_kernel_key("");
  kernel_with_handle->set_node_info("");
  kernel_with_handle->set_block_dim(32);
  kernel_with_handle->set_args_size(64);
  string args(64, '1');
  kernel_with_handle->set_args(args.data(), 64);
  domi::KernelContext *context = kernel_with_handle->mutable_context();
  context->set_op_index(0);
  context->set_kernel_type(2);  // ccKernelType::TE
  uint16_t args_offset[9] = {0};
  context->set_args_offset(args_offset, 9 * sizeof(uint16_t));
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  std::vector<char> kernelBin;
  TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
  op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
  std::string kernel_name("kernel/Add");
  AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = graph->AddNode(op_desc);
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  HybridModel hybrid_model(ge_root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);
  ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), SUCCESS);
}

TEST_F(UtestGeHybrid, init_weight_success) {
  NpuMemoryAllocator::allocators_.emplace(make_pair(0, nullptr));
  // make graph with sub_graph
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("root_graph");
  OpDescPtr op_desc = CreateOpDesc("if", IF);
  NodePtr node = graph->AddNode(op_desc);
  // make sub graph
  ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("if_sub_graph");
  OpDescPtr const_op_desc = CreateOpDesc("const", CONSTANT);
  vector<int64_t> dims_vec_0 = {2, 1, 4, 1, 2};
  vector<int32_t> data_vec_0 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  GeTensorDesc tensor_desc_0(GeShape(dims_vec_0), FORMAT_NCHW, DT_INT32);
  (void)TensorUtils::SetRealDimCnt(tensor_desc_0, dims_vec_0.size());
  ConstGeTensorPtr constTensor_0 =
      std::make_shared<GeTensor>(tensor_desc_0, (uint8_t *)&data_vec_0[0], data_vec_0.size() * sizeof(int32_t));
  AttrUtils::SetTensor(const_op_desc, ge::ATTR_NAME_WEIGHTS, constTensor_0);
  const_op_desc->AddOutputDesc(tensor_desc_0);
  NodePtr const_node = sub_graph->AddNode(const_op_desc);
  graph->AddSubgraph("sub", sub_graph);
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  GeModelPtr ge_sub_model = make_shared<GeModel>();
  // Buffer weight_buffer = Buffer(128, 0);
  // ge_sub_model->SetWeight(weight_buffer);
  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
  HybridModel hybrid_model(ge_root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);
  auto ret = hybrid_model_builder.InitWeights();
  ASSERT_EQ(ret, SUCCESS);
  Buffer weight_buffer = Buffer(128, 0);
  ge_sub_model->SetWeight(weight_buffer);
  ret = hybrid_model_builder.InitWeights();
  ASSERT_EQ(ret, PARAM_INVALID);
}

TEST_F(UtestGeHybrid, hybrid_model_executor) {
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc");
  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
  HybridModel model(root_model);
  model.root_graph_item_.reset(new GraphItem);
  HybridModel *model_ptr = &model;
  uint32_t device_id = 0;
  rtStream_t stream = nullptr;
  HybridModelExecutor executor(model_ptr, device_id, stream);
  executor.Init();
}

TEST_F(UtestGeHybrid, test_parse_parallel_group) {
  NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
                                                             NodeExecutorManager::ExecutorType::HCCL);
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
  OpDescPtr op_desc = CreateOpDesc("AllReduce", "AllReduce");
  op_desc->SetId(0);
  ge::AttrUtils::SetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, "group_1");
  auto node = compute_graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  node_item->node_id = 0;
  op_desc->SetOpKernelLibName("ops_kernel_info_hccl");
  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
  HybridModel model(root_model);
  model.root_graph_ = compute_graph;
  HybridModelBuilder builder(model);
  ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS);
  ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
  ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);
  OpDescPtr op_desc_1 = CreateOpDesc("subgraph", "PartitionedCall");
  op_desc_1->AddSubgraphName("subgraph");
  auto node_1 = compute_graph->AddNode(op_desc_1);
  ComputeGraphPtr subgraph = MakeShared<ComputeGraph>("subgraph");
  ASSERT_EQ(NodeUtils::SetSubgraph(*node_1, 0, subgraph), GRAPH_SUCCESS);
  std::unique_ptr<NodeItem> node_item_1;
  NodeItem::Create(node_1, node_item_1);
  node_item_1->node_id = 1;
  ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
  ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);
  ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 1);
  OpDescPtr op_desc_2 = CreateOpDesc("sub_node_1", "AllReduce");
  ge::AttrUtils::SetStr(op_desc_2, ATTR_NAME_PARALLEL_GROUP, "group_1");
  auto node_2 = subgraph->AddNode(op_desc_2);
  ASSERT_TRUE(node_2 != nullptr);
  OpDescPtr op_desc_3 = CreateOpDesc("sub_node_2", "AllReduce2");
  ge::AttrUtils::SetStr(op_desc_3, ATTR_NAME_PARALLEL_GROUP, "group_2");
  auto node_3 = subgraph->AddNode(op_desc_3);
  ASSERT_TRUE(node_3 != nullptr);
  ASSERT_EQ(builder.CollectParallelGroups(node_item_1.get()), SUCCESS);
  ASSERT_EQ(builder.node_to_parallel_groups_.size(), 2);
  ASSERT_EQ(builder.parallel_group_to_nodes_.size(), 2);
  ASSERT_EQ(builder.parallel_group_to_nodes_["group_1"].size(), 2);
  ASSERT_EQ(builder.parallel_group_to_nodes_["group_2"].size(), 1);
  builder.parallel_group_to_nodes_.clear();
  builder.node_ref_inputs_.clear();
  model.node_items_[node] = std::move(node_item);
  model.node_items_[node_1] = std::move(node_item_1);
  ASSERT_FALSE(model.node_items_[node]->has_observer);
  ASSERT_TRUE(model.node_items_[node_1]->dependents_for_execution.empty());
  ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
  ASSERT_TRUE(model.node_items_[node]->has_observer);
  ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1);
  ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node);
  // repeat parse
  ASSERT_EQ(builder.ParseDependentByParallelGroup(), SUCCESS);
  ASSERT_TRUE(model.node_items_[node]->has_observer);
  ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 1);
  ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution[0], node);
}

TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
  ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
  auto partitioned_call_op_desc = CreateOpDesc("partitioned_call", PARTITIONEDCALL, 3, 1);
  auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
  partitioned_call_op_desc->AddSubgraphName("f");
  partitioned_call_op_desc->SetSubgraphInstanceName(0, "sub_graph");
  ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
  {
    OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
    NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
    sub_sub_graph1->SetParentGraph(root_graph);
    root_graph->AddSubGraph(sub_sub_graph1);
  }
  ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
  {
    OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
    NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
    sub_sub_graph2->SetGraphUnknownFlag(true);
    sub_sub_graph2->SetParentGraph(root_graph);
    root_graph->AddSubGraph(sub_sub_graph2);
  }
  // Will unfold to merged_graph.
  ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
  {
    OpDescPtr sub_graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1);
    OpDescPtr sub_graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1);
    OpDescPtr sub_graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1);
    NodePtr sub_graph_data1_node = sub_graph->AddNode(sub_graph_data1_op_desc);
    NodePtr sub_graph_data2_node = sub_graph->AddNode(sub_graph_data2_op_desc);
    NodePtr sub_graph_data3_node = sub_graph->AddNode(sub_graph_data3_op_desc);
    AttrUtils::SetInt(sub_graph_data1_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 0);
    AttrUtils::SetInt(sub_graph_data2_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 1);
    AttrUtils::SetInt(sub_graph_data3_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 2);
    OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE, 2, 2);
    NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
    sub_sub_graph1->SetParentNode(sub_graph_while_node);
    sub_sub_graph2->SetParentNode(sub_graph_while_node);
    sub_graph_while_op_desc->AddSubgraphName("while_cond");
    sub_graph_while_op_desc->SetSubgraphInstanceName(0, "while_cond");
    sub_graph_while_op_desc->AddSubgraphName("while_body");
    sub_graph_while_op_desc->SetSubgraphInstanceName(1, "while_body");
    OpDescPtr sub_graph_matmul_op_desc = CreateOpDesc("matmul", MATMUL, 2, 1);
    NodePtr sub_graph_matmul_node = sub_graph->AddNode(sub_graph_matmul_op_desc);
    OpDescPtr sub_graph_output_op_desc = CreateOpDesc("output", NETOUTPUT, 1, 1);
    NodePtr sub_graph_output_node = sub_graph->AddNode(sub_graph_output_op_desc);
    GraphUtils::AddEdge(sub_graph_data1_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(0));
    GraphUtils::AddEdge(sub_graph_data2_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(1));
    GraphUtils::AddEdge(sub_graph_data3_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(0));
    GraphUtils::AddEdge(sub_graph_while_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(1));
    GraphUtils::AddEdge(sub_graph_matmul_node->GetOutDataAnchor(0), sub_graph_output_node->GetInDataAnchor(0));
    sub_graph->SetGraphUnknownFlag(true);
    sub_graph->SetParentNode(partitioned_call_node);
    sub_graph->SetParentGraph(root_graph);
    root_graph->AddSubGraph(sub_graph);
  }
  OpDescPtr graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1);
  OpDescPtr graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1);
  OpDescPtr graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1);
  NodePtr graph_data1_node = root_graph->AddNode(graph_data1_op_desc);
  NodePtr graph_data2_node = root_graph->AddNode(graph_data2_op_desc);
  NodePtr graph_data3_node = root_graph->AddNode(graph_data3_op_desc);
  AttrUtils::SetInt(graph_data1_op_desc, ATTR_NAME_INDEX, 0);
  AttrUtils::SetInt(graph_data2_op_desc, ATTR_NAME_INDEX, 1);
  AttrUtils::SetInt(graph_data3_op_desc, ATTR_NAME_INDEX, 2);
  GraphUtils::AddEdge(graph_data1_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(0));
  GraphUtils::AddEdge(graph_data2_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(1));
  GraphUtils::AddEdge(graph_data3_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(2));
  ComputeGraphPtr merged_graph = nullptr;
  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
  HybridModel hybrid_model(root_model);
  HybridModelBuilder hybrid_model_builder(hybrid_model);
  EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
}

TEST_F(UtestGeHybrid, TestTaskContext) {
  auto graph = make_shared<ComputeGraph>("graph");
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  GeShape shape({2, 16});
  GeTensorDesc tensor_desc(shape);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);
  auto node = graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  node_item->input_start = 0;
  node_item->output_start = 0;
  GraphExecutionContext execution_context;
  GraphItem graph_item;
  SubgraphContext subgraph_context(&graph_item, &execution_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  subgraph_context.all_inputs_.resize(2);
  subgraph_context.all_outputs_.resize(1);
  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
  auto task_context = node_state->GetTaskContext();
  ASSERT_TRUE(task_context != nullptr);
  auto desc = task_context->MutableInputDesc(2);
  ASSERT_TRUE(desc == nullptr);
  desc = task_context->MutableOutputDesc(0);
  ASSERT_TRUE(desc != nullptr);
  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
  GeTensorDesc output_desc;
  ASSERT_EQ(task_context->GetOutputDesc(0, output_desc), SUCCESS);
  ASSERT_EQ(output_desc.GetShape().GetDims(), shape.GetDims());
  desc = task_context->MutableInputDesc(0);
  ASSERT_TRUE(desc != nullptr);
  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
  GeShape new_shape({8, 2});
  tensor_desc.SetShape(new_shape);
  task_context->UpdateInputDesc(1, tensor_desc);
  GeTensorDesc new_desc;
  ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS);
  ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims());
}

TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) {
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  auto graph = make_shared<ComputeGraph>("graph");
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  GeShape shape({2, 16});
  GeTensorDesc tensor_desc(shape);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);
  auto node = graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  node_item->input_start = 0;
  node_item->output_start = 0;
  GraphExecutionContext execution_context;
  GraphItem graph_item;
  SubgraphContext subgraph_context(&graph_item, &execution_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  subgraph_context.all_inputs_.resize(2);
  subgraph_context.all_outputs_.resize(1);
  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
  auto task_context = node_state->GetTaskContext();
  int32_t buffer[1];
  aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer));
  EXPECT_NE(aicore_task->tiling_buffer_, nullptr);
  aicore_task->max_arg_count_ = 0;
  EXPECT_EQ(aicore_task->UpdateArgs(*task_context), ACL_ERROR_GE_MEMORY_OPERATE_FAILED);
  aicore_task->args_ = std::unique_ptr<uint8_t[]>(new uint8_t[sizeof(uintptr_t) * 2]);
  EXPECT_EQ(aicore_task->UpdateArgs(*task_context), SUCCESS);
}

TEST_F(UtestGeHybrid, hybrid_model_executor_check_shape) {
  HybridModelExecutor::ExecuteArgs args;
  GeTensorDescPtr ge_tensor = make_shared<GeTensorDesc>(GeTensorDesc());
  vector<int64_t> dim = {2, 3};
  ge_tensor->SetShape(GeShape(dim));
  args.input_desc.push_back(ge_tensor);
  // create node
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("God");
  OpDescPtr op_desc = std::make_shared<OpDesc>("data", DATA);
  GeTensorDesc tensor_desc(GeShape({2, 3}));
  std::vector<std::pair<int64_t, int64_t>> shape_range({std::pair<int64_t, int64_t>(1, 3),
                                                        std::pair<int64_t, int64_t>(2, 4)});
  tensor_desc.SetShapeRange(shape_range);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);
  NodePtr node = graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> new_node;
  NodeItem::Create(node, new_node);
  new_node->is_dynamic = true;
  GraphItem graph_item;
  graph_item.input_nodes_.emplace_back(new_node.get());
  Status ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args);
  ASSERT_EQ(ret, ge::SUCCESS);
  HybridModelExecutor::ExecuteArgs args1;
  ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args1);
  ASSERT_EQ(ret, ge::INTERNAL_ERROR);
  HybridModelExecutor::ExecuteArgs args2;
  GeTensorDescPtr ge_tensor2 = make_shared<GeTensorDesc>(GeTensorDesc());
  vector<int64_t> dim2 = {-1, 3};
  ge_tensor2->SetShape(GeShape(dim2));
  args2.input_desc.push_back(ge_tensor2);
  ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args1);
  ASSERT_EQ(ret, ge::INTERNAL_ERROR);
  HybridModelExecutor::ExecuteArgs args3;
  ret = HybridModelExecutor::CheckInputShapeByShapeRange(&graph_item, args3);
  ASSERT_EQ(ret, ge::INTERNAL_ERROR);
}

TEST_F(UtestGeHybrid, TestOptimizeDependenciesForConstInputs) {
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
  HybridModel model(root_model);
  model.root_graph_ = compute_graph;
  HybridModelBuilder builder(model);
  GeShape shape({2, 16});
  GeTensorDesc tensor_desc(shape);
  std::unique_ptr<NodeItem> const_node_item;
  {
    OpDescPtr const_op_desc = CreateOpDesc("Constant", "Const");
    const_op_desc->AddOutputDesc(tensor_desc);
    auto const_node = compute_graph->AddNode(const_op_desc);
    NodeItem::Create(const_node, const_node_item);
  }
  std::unique_ptr<NodeItem> non_const_node_item;
  {
    OpDescPtr op_desc = CreateOpDesc("Add", "Add");
    op_desc->AddOutputDesc(tensor_desc);
    auto const_node = compute_graph->AddNode(op_desc);
    NodeItem::Create(const_node, non_const_node_item);
  }
  std::unique_ptr<NodeItem> known_node_item;
  {
    OpDescPtr known_op_desc = CreateOpDesc("known", "PartitionedCall");
    known_op_desc->AddOutputDesc(tensor_desc);
    known_op_desc->AddOutputDesc(tensor_desc);
    auto known_node = compute_graph->AddNode(known_op_desc);
    NodeItem::Create(known_node, known_node_item);
  }
  std::unique_ptr<NodeItem> dst_node_item;
  {
    OpDescPtr known_op_desc = CreateOpDesc("SomeOp", "SomeOpType ");
    known_op_desc->AddOutputDesc(tensor_desc);
    known_op_desc->AddOutputDesc(tensor_desc);
    auto known_node = compute_graph->AddNode(known_op_desc);
    NodeItem::Create(known_node, dst_node_item);
  }
  float buffer[2 * 16];
  unique_ptr<TensorValue> tensor_value(new TensorValue(buffer, sizeof(buffer)));
  model.constant_tensors_[const_node_item->node] = std::move(tensor_value);
  // Case 1. connect to Const
  auto output_id = 1;
  builder.host_input_value_dependencies_[dst_node_item.get()].emplace_back(output_id, const_node_item.get());
  builder.host_input_value_dependencies_[dst_node_item.get()].emplace_back(0, non_const_node_item.get());
  dst_node_item->dependents_for_shape_inference.emplace_back(const_node_item->node);
  dst_node_item->dependents_for_shape_inference.emplace_back(non_const_node_item->node);
  ASSERT_EQ(builder.OptimizeDependenciesForConstantInputs(), SUCCESS);
  ASSERT_EQ(dst_node_item->dependents_for_shape_inference.size(), 1);
  ASSERT_EQ(dst_node_item->dependents_for_shape_inference[0], non_const_node_item->node);
  // Case 2. connect to known-subgraph, netoutput connect to Const
  builder.host_input_value_dependencies_.clear();
  dst_node_item->dependents_for_shape_inference.clear();
  builder.known_subgraph_constant_output_refs_[known_node_item.get()].emplace(output_id, const_node_item->node);
  builder.host_input_value_dependencies_[dst_node_item.get()].emplace_back(output_id, known_node_item.get());
  builder.host_input_value_dependencies_[dst_node_item.get()].emplace_back(0, non_const_node_item.get());
  dst_node_item->dependents_for_shape_inference.emplace_back(known_node_item->node);
  dst_node_item->dependents_for_shape_inference.emplace_back(non_const_node_item->node);
  ASSERT_EQ(builder.OptimizeDependenciesForConstantInputs(), SUCCESS);
  ASSERT_EQ(dst_node_item->dependents_for_shape_inference.size(), 1);
  ASSERT_EQ(dst_node_item->dependents_for_shape_inference[0], non_const_node_item->node);
}

TEST_F(UtestGeHybrid, test_key_for_kernel_bin) {
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  OpDesc op_desc("Sum", "Sum");
  EXPECT_EQ(aicore_task->GetKeyForTbeKernel(), OP_EXTATTR_NAME_TBE_KERNEL);
  EXPECT_EQ(aicore_task->GetKeyForTvmMagic(), TVM_ATTR_NAME_MAGIC);
  EXPECT_EQ(aicore_task->GetKeyForTvmMetaData(), TVM_ATTR_NAME_METADATA);
  EXPECT_EQ(aicore_task->GetKeyForKernelName(op_desc), "Sum_kernelname");
  auto atomic_task = std::unique_ptr<hybrid::AtomicAddrCleanOpTask>(new(std::nothrow)hybrid::AtomicAddrCleanOpTask());
  EXPECT_EQ(atomic_task->GetKeyForTbeKernel(), EXT_ATTR_ATOMIC_TBE_KERNEL);
  EXPECT_EQ(atomic_task->GetKeyForTvmMagic(), ATOMIC_ATTR_TVM_MAGIC);
  EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), ATOMIC_ATTR_TVM_METADATA);
  EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname");
}

TEST_F(UtestGeHybrid, test_op_type) {
  auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
  aicore_task->op_type_ = "Add";
  EXPECT_EQ(aicore_task->GetOpType(), "Add");
  auto atomic_task = std::unique_ptr<hybrid::AtomicAddrCleanOpTask>(new(std::nothrow)hybrid::AtomicAddrCleanOpTask());
  EXPECT_EQ(atomic_task->GetOpType(), "DynamicAtomicAddrClean");
}

TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) {
  NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
                                                             NodeExecutorManager::ExecutorType::HCCL);
  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  auto node = compute_graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  node_item->node_id = 0;
  OpDescPtr op_desc_1 = CreateOpDesc("AllReduce", "AllReduce");
  op_desc_1->SetOpKernelLibName("ops_kernel_info_hccl");
  auto node_1 = compute_graph->AddNode(op_desc_1);
  std::unique_ptr<NodeItem> node_item_1;
  NodeItem::Create(node_1, node_item_1);
  node_item_1->node_id = 1;
  node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor());
  OpDescPtr op_desc_2 = CreateOpDesc("net_output", NETOUTPUT);
  auto node_2 = compute_graph->AddNode(op_desc_2);
  std::unique_ptr<NodeItem> node_item_2;
  NodeItem::Create(node_2, node_item_2);
  node_item_2->node_id = 2;
  node_1->GetOutControlAnchor()->LinkTo(node_2->GetInControlAnchor());
  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
  HybridModel model(root_model);
  model.root_graph_ = compute_graph;
  model.node_items_.emplace(node, std::move(node_item));
  model.node_items_.emplace(node_1, std::move(node_item_1));
  model.node_items_.emplace(node_2, std::move(node_item_2));
  HybridModelBuilder builder(model);
  std::vector<std::string> deps;
  ASSERT_EQ(builder.ParseDependentInputNodes(*model.node_items_[node_1], deps), SUCCESS);
  ASSERT_EQ(builder.ParseDependentInputNodes(*model.node_items_[node_2], deps), SUCCESS);
  ASSERT_FALSE(model.GetNodeItem(node)->has_observer);
  ASSERT_TRUE(model.GetNodeItem(node_1)->has_observer);
  ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 0);
  ASSERT_EQ(model.node_items_[node_2]->dependents_for_execution.size(), 1);
}

TEST_F(UtestGeHybrid, TestParseDependencies) {
  // make graph
  ut::GraphBuilder graph_builder = ut::GraphBuilder("graph");
  auto data = graph_builder.AddNode("Data", "Data", 0, 1);
  auto netoutput = graph_builder.AddNode("Netoutput", "NetOutput", 1, 0);
  graph_builder.AddDataEdge(data, 0, netoutput, 0);
  auto graph = graph_builder.GetGraph();
  GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(graph);
  HybridModel model(root_model);
  HybridModelBuilder builder(model);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(netoutput, node_item);
  std::unique_ptr<NodeItem> node_item2;
  NodeItem::Create(data, node_item2);
  model.node_items_.emplace(data, std::move(node_item2));
  std::vector<std::string> deps;
  deps.push_back("Data");
  auto op_desc = netoutput->GetOpDesc();
  op_desc->impl_->input_name_idx_["Data"] = 0;
  auto data_desc = data->GetOpDesc();
  auto tensor = std::make_shared<GeTensor>();
  auto tensor_desc = data_desc->MutableInputDesc(0);
  AttrUtils::SetTensor(tensor_desc, "_value", tensor);
  std::set<NodePtr> dependent_for_shape_inference;
  ASSERT_EQ(builder.ParseDependencies(*node_item, deps, dependent_for_shape_inference), SUCCESS);
}

TEST_F(UtestGeHybrid, TestTaskExecuteAsync) {
  auto graph = make_shared<ComputeGraph>("graph");
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  GeShape shape({2, 16});
  GeTensorDesc tensor_desc(shape);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);
  auto node = graph->AddNode(op_desc);
  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  node_item->input_start = 0;
  node_item->output_start = 0;
  GraphExecutionContext execution_context;
  GraphItem graph_item;
  SubgraphContext subgraph_context(&graph_item, &execution_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  subgraph_context.all_inputs_.resize(2);
  subgraph_context.all_outputs_.resize(1);
  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
  auto task_context = *node_state->GetTaskContext();
  ASSERT_NE(BuildTaskUtils::GetTaskInfo(task_context), "");
  std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask());
  std::vector<std::unique_ptr<AiCoreOpTask>> tasks;
  AiCoreNodeTask node_task(std::move(tasks));
  ASSERT_EQ(node_task.ExecuteAsync(task_context, nullptr), SUCCESS);
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware and serves as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that runs efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture is shown in the diagram below.
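As a rough point of reference only, the sketch below strings together the classes already exercised in ge_hybrid_unittest.cc above (ComputeGraph, GeRootModel, HybridModel, HybridModelBuilder). It assumes the same headers and namespaces as the test file, is not a complete GE usage example, and the function name BuildHybridModelSketch is purely illustrative; as the index_taskdefs_failed test shows, building an empty root graph is expected to fail.

// Minimal sketch, assuming the includes used by ge_hybrid_unittest.cc are available.
// The empty root graph stands in for a graph delivered by the front end (ME).
ge::Status BuildHybridModelSketch() {
  auto root_graph = std::make_shared<ge::ComputeGraph>("root_graph");
  auto root_model = std::make_shared<ge::GeRootModel>(root_graph);

  // HybridModelBuilder lowers the GeRootModel into an executable HybridModel,
  // the same path the index_taskdefs_* tests above exercise.
  ge::hybrid::HybridModel hybrid_model(root_model);
  ge::hybrid::HybridModelBuilder builder(hybrid_model);
  // Returns INTERNAL_ERROR here because the graph is empty, mirroring index_taskdefs_failed.
  return builder.Build();
}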