You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cast_remove_pass_unittest.cc 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <vector>
  18. #define protected public
  19. #define private public
  20. #include "graph/passes/cast_remove_pass.h"
  21. #undef protected
  22. #undef private
  23. #include "anchor.h"
  24. #include "common/debug/log.h"
  25. #include "common/debug/memory_dumper.h"
  26. #include "common/op/attr_value_util.h"
  27. #include "common/types.h"
  28. #include "framework/common/ge_inner_error_codes.h"
  29. #include "graph/attr_value.h"
  30. #include "graph/debug/ge_attr_define.h"
  31. #include "inc/pass_manager.h"
  32. #include "graph_builder_utils.h"
  33. #include <string>
  34. #include <iostream>
  35. #include <vector>
  36. #include "opskernel_manager/ops_kernel_manager.h"
  37. #include "omg/omg_inner_types.h"
  38. using namespace testing;
  39. using namespace ge;
  40. using namespace std;
  41. class UtestGraphPassesCastRemovePass : public testing::Test {
  42. protected:
  43. void SetUp() {}
  44. void TearDown() {}
  45. };
  46. // case1: no net_out_put_node (commented-out variant of DoFuseProcess that calls DoFuse and DoRemoveCast directly)
  47. // TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) {
  48. // std::vector<NodePtr> nodes_to_fuse;
  49. // auto builder = ut::GraphBuilder("g1");
  50. // auto data = builder.AddNode("data", DATA, 1, 1);
  51. // auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
  52. // cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);
  53. // auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16);
  54. // auto cast2 = builder.AddNode("cast2", CAST, 1, 1);
  55. // cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16);
  56. // auto net = builder.AddNode("netout", NETOUTPUT, 1, 1);
  57. // builder.AddDataEdge(data, 0, cast1, 0);
  58. // builder.AddDataEdge(cast1, 0, trans, 0);
  59. // builder.AddDataEdge(trans, 0, cast2, 0);
  60. // builder.AddDataEdge(cast2, 0, net, 0);
  61. // ComputeGraphPtr compute_graph = builder.GetGraph();
  62. // map<string, string> options;
  63. // CastRemovePass cast_remove_pass;
  64. // DataType type = DT_FLOAT;
  65. // nodes_to_fuse.emplace_back(cast1);
  66. // nodes_to_fuse.emplace_back(trans);
  67. // nodes_to_fuse.emplace_back(cast2);
  68. // OpsKernelManager ops_kernel_manager;
  69. // cast_remove_pass.DoFuse(ops_kernel_manager, type, nodes_to_fuse);
  70. // EXPECT_EQ(compute_graph->GetAllNodesSize(),5);
  71. // std::vector<size_t> to_be_deleted_cast_index;
  72. // to_be_deleted_cast_index.emplace_back(0);
  73. // to_be_deleted_cast_index.emplace_back(2);
  74. // (void)cast_remove_pass.DoRemoveCast(to_be_deleted_cast_index, nodes_to_fuse);
  75. // EXPECT_EQ(compute_graph->GetAllNodesSize(),3);
  76. // }
  77. TEST_F(UtestGraphPassesCastRemovePass, DoFuseProcess) {
  78. std::vector<NodePtr> nodes_to_fuse;
  79. auto builder = ut::GraphBuilder("g1");
  80. auto data = builder.AddNode("data", DATA, 1, 1);
  81. auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
  82. cast1->GetOpDesc()->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);
  83. auto trans = builder.AddNode("trans", TRANSPOSE, 1, 1, FORMAT_NCHW, DT_FLOAT16);
  84. auto cast2 = builder.AddNode("cast2", CAST, 1, 1);
  85. cast2->GetOpDesc()->MutableInputDesc(0)->SetDataType(DT_FLOAT16);
  86. auto net = builder.AddNode("netout", NETOUTPUT, 1, 1);
  87. builder.AddDataEdge(data, 0, cast1, 0);
  88. builder.AddDataEdge(cast1, 0, trans, 0);
  89. builder.AddDataEdge(trans, 0, cast2, 0);
  90. builder.AddDataEdge(cast2, 0, net, 0);
  91. ComputeGraphPtr compute_graph = builder.GetGraph();
  92. map<string, string> options;
  93. CastRemovePass cast_remove_pass;
  94. DataType type = DT_FLOAT;
  95. nodes_to_fuse.emplace_back(cast1);
  96. nodes_to_fuse.emplace_back(trans);
  97. nodes_to_fuse.emplace_back(cast2);
  98. cast_remove_pass.RemoveCast(type, nodes_to_fuse);
  99. EXPECT_EQ(compute_graph->GetAllNodesSize(),3);
  100. }

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: