You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fusion_pattern.h 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_
  17. #define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_
  18. #include <iostream>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <vector>
  23. using std::initializer_list;
  24. using std::map;
  25. using std::string;
  26. using std::vector;
  27. using namespace std;
  28. namespace fe {
  29. /** Fusion pattern
  30. * @ingroup FUSION_PASS_GROUP
  31. * Describe Pattern of Ops waiting for fusion(Op type, etc)
  32. */
  33. class FusionPattern {
  34. public:
  35. struct OpDesc;
  36. using OpDescPtr = std::shared_ptr<OpDesc>;
  37. /**
  38. * @ingroup fe
  39. * @brief description of Ops
  40. */
  41. struct OpDesc {
  42. string id; // Identifier
  43. std::vector<std::string> types; // the Op types of Ops
  44. std::vector<OpDescPtr> inputs; // all input Ops
  45. bool repeatable; // flag to show if match multiple Ops or not
  46. bool is_output; // flag to show if the op is output node
  47. };
  48. public:
  49. explicit FusionPattern(string name = "");
  50. ~FusionPattern();
  51. /** set pattern name
  52. *
  53. * @param name pattern name
  54. * @return FusionPattern
  55. */
  56. FusionPattern &SetName(const string &name);
  57. /** add Op description with unknown number of args
  58. *
  59. * @param id pattern id
  60. * @param types op type list
  61. * @return FusionPattern
  62. */
  63. FusionPattern &AddOpDesc(const string &id, const initializer_list<string> &types = {});
  64. /** add Op description with vector
  65. *
  66. * @param id pattern id
  67. * @param types op type list
  68. *
  69. * @return FusionPattern
  70. */
  71. FusionPattern &AddOpDesc(const string &id, const vector<string> &types);
  72. /** set input Ops with unknown number of args
  73. *
  74. * @param id pattern id
  75. *
  76. * @param input_ids inputs to id op
  77. *
  78. * @return FusionPattern
  79. */
  80. FusionPattern &SetInputs(const string &id, const initializer_list<string> &input_ids);
  81. /** set input Ops with unknown number of args
  82. *
  83. * @param id pattern id
  84. *
  85. * @param input_ids inputs to id op
  86. *
  87. * @return FusionPattern
  88. */
  89. FusionPattern &SetInputs(const string &id, const vector<string> &input_ids);
  90. /** set output Op
  91. *
  92. * @param id pattern id
  93. *
  94. * @return FusionPattern
  95. */
  96. FusionPattern &SetOutput(const string &id);
  97. /** build pattern and check if error exists
  98. *
  99. * @return True or False
  100. */
  101. bool Build();
  102. /** get pattern name
  103. *
  104. * @param id pattern id
  105. *
  106. * @return fusion pattern name
  107. */
  108. const string &GetName() const;
  109. /** get the OpDesc of input Ops (const)
  110. *
  111. * @param op_desc op_desc for getting inputs
  112. *
  113. * @return op_desc's iniput opdesc list
  114. */
  115. static const vector<std::shared_ptr<OpDesc>> *GetInputs(std::shared_ptr<OpDesc> op_desc);
  116. /** get the OpDesc of output Op
  117. *
  118. * @return pattern's output opdesc list
  119. */
  120. const std::shared_ptr<FusionPattern::OpDesc> GetOutput() const;
  121. /** print pattern
  122. *
  123. */
  124. void Dump() const;
  125. void GetOpDescList(vector<std::shared_ptr<OpDesc>> &op_desc_list);
  126. /** get OpDesc based on ID, return nullptr if failed
  127. *
  128. * @param id pattern id
  129. *
  130. * @return pattern's output opdesc list
  131. */
  132. std::shared_ptr<FusionPattern::OpDesc> GetOpDesc(const string &id) const;
  133. private:
  134. FusionPattern(const FusionPattern &) = default;
  135. FusionPattern &operator=(const FusionPattern &) = default;
  136. void SetError();
  137. private:
  138. string name_;
  139. vector<std::shared_ptr<OpDesc>> ops_;
  140. map<string, std::shared_ptr<OpDesc>> op_map_;
  141. std::shared_ptr<OpDesc> output_;
  142. bool has_error_;
  143. };
  144. } // namespace fe
  145. #endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。