You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

operator_reg.h 37 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
  17. #define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
  18. #include <functional>
  19. #include <memory>
  20. #include <string>
  21. #include <vector>
  22. #include "graph/operator.h"
  23. #include "graph/operator_factory.h"
  24. #include "graph/tensor.h"
  25. #include "graph/types.h"
  26. #include "graph/graph.h"
  27. namespace ge {
  28. using std::function;
  29. using std::string;
  30. using std::vector;
  // Extra accessors that ATTR(x, String, default) splices into the operator
  // class (through ATTR_##Type) so the attribute can be used with AscendString:
  //   get_attr_x(AscendString &) - stores the attribute in its argument; when
  //     Operator::GetAttr reports GRAPH_FAILED it stores the default instead
  //     (__VA_ARGS__, held in a `string`). Always returns GRAPH_SUCCESS.
  //   set_attr_x(const char *) - writes the value through Operator::SetAttr.
  //   set_attr_x(const function<AscendString()> &) - ignores its argument and
  //     returns *this unchanged.
  #define ATTR_String(x, ...) \
    graphStatus get_attr_##x(AscendString &value) const { \
      const string fallback = __VA_ARGS__; \
      if (Operator::GetAttr(#x, value) != GRAPH_FAILED) { \
        return GRAPH_SUCCESS; \
      } \
      value = AscendString(fallback.c_str()); \
      return GRAPH_SUCCESS; \
    } \
    _THIS_TYPE &set_attr_##x(const char *v) { \
      Operator::SetAttr(#x, v); \
      return *this; \
    } \
    _THIS_TYPE &set_attr_##x(const function<AscendString()> &) { \
      return *this; \
    }

  // List counterpart of ATTR_String, spliced in by ATTR(x, ListString, default):
  //   get_attr_x(vector<AscendString> &) - fills its argument via
  //     Operator::GetAttr; on GRAPH_FAILED it appends every default string
  //     (__VA_ARGS__, a vector<string> initializer) with emplace_back, without
  //     clearing the vector first. Always returns GRAPH_SUCCESS.
  //   set_attr_x(const vector<AscendString> &) - writes through Operator::SetAttr.
  //   set_attr_x(const function<vector<AscendString>()> &) - no-op returning *this.
  #define ATTR_ListString(x, ...) \
    graphStatus get_attr_##x(vector<AscendString> &values) const { \
      vector<string> fallbacks = __VA_ARGS__; \
      if (Operator::GetAttr(#x, values) != GRAPH_FAILED) { \
        return GRAPH_SUCCESS; \
      } \
      for (string &item : fallbacks) { \
        values.emplace_back(item.c_str()); \
      } \
      return GRAPH_SUCCESS; \
    } \
    _THIS_TYPE &set_attr_##x(const vector<AscendString> &v) { \
      Operator::SetAttr(#x, v); \
      return *this; \
    } \
    _THIS_TYPE &set_attr_##x(const function<vector<AscendString>()> &) { \
      return *this; \
    }
  // Extra accessor that ATTR(x, AscendString, default) splices into the
  // operator class (through ATTR_##Type):
  //   get_attr_x(AscendString &ret) - stores the attribute in `ret`; when
  //     Operator::GetAttr reports GRAPH_FAILED it stores a copy of the default
  //     (__VA_ARGS__) instead. Always returns GRAPH_SUCCESS.
  // The default is already an AscendString, so it is copied directly. The
  // previous `AscendString(ret_str.c_str())` relied on a c_str() member, but
  // the sibling ATTR_ListAscendString only uses AscendString::GetString().
  #define ATTR_AscendString(x, ...) \
    graphStatus get_attr_##x(AscendString &ret) const { \
      AscendString ret_str = __VA_ARGS__; \
      if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
        ret = ret_str; \
      } \
      return GRAPH_SUCCESS; \
    }
  // Extra accessor spliced in by ATTR(x, ListAscendString, default):
  //   get_attr_x(vector<AscendString> &) - fills its argument via
  //     Operator::GetAttr; on GRAPH_FAILED it appends every default element
  //     (__VA_ARGS__) whose GetString() is non-null and skips null ones.
  //     Always returns GRAPH_SUCCESS.
  // This macro generates no extra setters.
  #define ATTR_ListAscendString(x, ...) \
    graphStatus get_attr_##x(vector<AscendString> &values) const { \
      vector<AscendString> fallbacks = __VA_ARGS__; \
      if (Operator::GetAttr(#x, values) != GRAPH_FAILED) { \
        return GRAPH_SUCCESS; \
      } \
      for (auto &item : fallbacks) { \
        if (item.GetString() == nullptr) { \
          continue; \
        } \
        values.emplace_back(item.GetString()); \
      } \
      return GRAPH_SUCCESS; \
    }

  // The remaining attribute types get no extra members from ATTR(x, Type, ...).
  // Only its generic Op<Type> accessors are emitted for them.
  // Scalar types:
  #define ATTR_Bool(x, ...)
  #define ATTR_Bytes(x, ...)
  #define ATTR_Float(x, ...)
  #define ATTR_Int(x, ...)
  #define ATTR_NamedAttrs(x, ...)
  #define ATTR_Tensor(x, ...)
  #define ATTR_Type(x, ...)
  // List types:
  #define ATTR_ListBool(x, ...)
  #define ATTR_ListFloat(x, ...)
  #define ATTR_ListInt(x, ...)
  #define ATTR_ListListInt(x, ...)
  #define ATTR_ListNamedAttrs(x, ...)
  #define ATTR_ListTensor(x, ...)
  #define ATTR_ListType(x, ...)
  // Extra accessors spliced in by REQUIRED_ATTR(x, String). There is no default:
  //   get_attr_x(AscendString &) - returns GRAPH_FAILED if Operator::GetAttr
  //     reports GRAPH_FAILED, GRAPH_SUCCESS otherwise.
  //   set_attr_x(const char *) - writes through Operator::SetAttr.
  //   set_attr_x(const function<AscendString()> &) - no-op returning *this.
  #define REQUIRED_ATTR_String(x) \
    graphStatus get_attr_##x(AscendString &value) const { \
      return (Operator::GetAttr(#x, value) == GRAPH_FAILED) ? GRAPH_FAILED : GRAPH_SUCCESS; \
    } \
    _THIS_TYPE &set_attr_##x(const char *v) { \
      Operator::SetAttr(#x, v); \
      return *this; \
    } \
    _THIS_TYPE &set_attr_##x(const function<AscendString()> &) { \
      return *this; \
    }

  // List counterpart, spliced in by REQUIRED_ATTR(x, ListString):
  //   get_attr_x(vector<AscendString> &) - GRAPH_FAILED if Operator::GetAttr
  //     reports GRAPH_FAILED, GRAPH_SUCCESS otherwise.
  //   set_attr_x(const vector<AscendString> &) - writes through Operator::SetAttr.
  //   set_attr_x(const function<vector<AscendString>()> &) - no-op returning *this.
  #define REQUIRED_ATTR_ListString(x) \
    graphStatus get_attr_##x(vector<AscendString> &values) const { \
      return (Operator::GetAttr(#x, values) == GRAPH_FAILED) ? GRAPH_FAILED : GRAPH_SUCCESS; \
    } \
    _THIS_TYPE &set_attr_##x(const vector<AscendString> &v) { \
      Operator::SetAttr(#x, v); \
      return *this; \
    } \
    _THIS_TYPE &set_attr_##x(const function<vector<AscendString>()> &) { \
      return *this; \
    }
  // Extra accessor spliced in by REQUIRED_ATTR(x, AscendString). There is no default:
  //   get_attr_x(AscendString &ret) - reads the attribute into `ret` through
  //     Operator::GetAttr; returns GRAPH_FAILED when that call reports
  //     GRAPH_FAILED, GRAPH_SUCCESS otherwise.
  // The `;` after `return GRAPH_FAILED` is required; without it every expansion
  // of this macro fails to compile.
  #define REQUIRED_ATTR_AscendString(x) \
    graphStatus get_attr_##x(AscendString &ret) const { \
      if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
        return GRAPH_FAILED; \
      } \
      return GRAPH_SUCCESS; \
    }
  // Extra accessor spliced in by REQUIRED_ATTR(x, ListAscendString):
  //   get_attr_x(vector<AscendString> &) - GRAPH_FAILED if Operator::GetAttr
  //     reports GRAPH_FAILED, GRAPH_SUCCESS otherwise.
  // There is no default, and this macro generates no extra setters.
  #define REQUIRED_ATTR_ListAscendString(x) \
    graphStatus get_attr_##x(vector<AscendString> &values) const { \
      return (Operator::GetAttr(#x, values) == GRAPH_FAILED) ? GRAPH_FAILED : GRAPH_SUCCESS; \
    }

  // The remaining required-attribute types get nothing extra from
  // REQUIRED_ATTR(x, Type). Only its generic Op<Type> accessors are emitted for them.
  // Scalar types:
  #define REQUIRED_ATTR_Bool(x)
  #define REQUIRED_ATTR_Bytes(x)
  #define REQUIRED_ATTR_Float(x)
  #define REQUIRED_ATTR_Int(x)
  #define REQUIRED_ATTR_NamedAttrs(x)
  #define REQUIRED_ATTR_Tensor(x)
  #define REQUIRED_ATTR_Type(x)
  // List types:
  #define REQUIRED_ATTR_ListBool(x)
  #define REQUIRED_ATTR_ListFloat(x)
  #define REQUIRED_ATTR_ListInt(x)
  #define REQUIRED_ATTR_ListListInt(x)
  #define REQUIRED_ATTR_ListNamedAttrs(x)
  #define REQUIRED_ATTR_ListTensor(x)
  #define REQUIRED_ATTR_ListType(x)
  // OpReg: placeholder object at the end of every generated initializer.
  // REG_OP and each chaining macro end their expansion with a bare `OpReg()`
  // temporary, and the next macro continues it with `.N();`. N() is therefore
  // the only member the macros in this file call. Every member returns the
  // same object, so any such expression stays chainable.
  // NOTE: these members share their names with function-like macros defined
  // later in this header (ATTR, INPUT, GRAPH, ...). They must stay declared
  // before those #defines; otherwise e.g. `ATTR()` would be read as an
  // invocation of the ATTR macro.
  class OpReg {
   public:
    OpReg &N() { return Self(); }
    OpReg &ATTR() { return Self(); }
    OpReg &REQUIRED_ATTR() { return Self(); }
    OpReg &INPUT() { return Self(); }
    OpReg &OPTIONAL_INPUT() { return Self(); }
    OpReg &OUTPUT() { return Self(); }
    OpReg &GRAPH() { return Self(); }
    OpReg &DYNAMIC_GRAPH() { return Self(); }
    OpReg &INFER_SHAPE_AND_TYPE() { return Self(); }

   private:
    // Shared body of every chaining member: hand back this same object.
    OpReg &Self() { return *this; }
  };
  // REG_OP(x): opens the definition of operator class `op::x`.
  //
  // Expands to `namespace op { class x : public Operator { ...` and emits:
  //   - `_THIS_TYPE`, a typedef for `x`. The member-generating macros below use it
  //     as the return type of their chainable setters.
  //   - Constructors taking a std::string (marked ATTRIBUTED_DEPRECATED in favour
  //     of the `const char *` overload), a `const char *`, or an AscendString, plus
  //     a default constructor that passes only `#x` to the Operator constructor.
  //     Every constructor then calls the private initializer `__x()`.
  // The expansion deliberately stops inside the body of `__x()`, after a bare
  // `OpReg()` temporary. The next chained macro (`.INPUT(...)`, `.ATTR(...)`,
  // ..., `.OP_END_FACTORY_REG(x)`) continues that expression. The function,
  // class and namespace opened here are closed only by OP_END_FACTORY_REG(x), so
  // REG_OP must always be paired with it.
  // NOTE(review): names such as `__x` contain a double underscore, which the C++
  // standard reserves for the implementation.
  159. #define REG_OP(x) \
  160. namespace op { \
  161. class x : public Operator { \
  162. typedef x _THIS_TYPE; \
  163. \
  164. public: \
  165. ATTRIBUTED_DEPRECATED(x(const char *)) \
  166. explicit x(const string &name) : Operator(name.c_str(), #x) { __##x(); } \
  167. explicit x(const char *name) : Operator(name, #x) { __##x(); } \
  168. explicit x(const AscendString &name) : Operator(name, #x) { \
  169. __##x(); } \
  170. x() : Operator(#x) { __##x(); } \
  171. \
  172. private: \
  173. void __##x() { \
  174. OpReg()
  // ATTR(x, Type, ...): declares optional attribute `x` of type `Op<Type>` whose
  // default value is `__VA_ARGS__`. It is used as a link in the REG_OP chain.
  //
  // The expansion first finishes the previous initializer (`.N(); __attr_x(); }`),
  // so the previous link calls this attribute's initializer. It then emits these
  // public members:
  //   - name_attr_x(): the attribute name as std::string, ATTRIBUTED_DEPRECATED in
  //     favour of name_attr_x(AscendString &), which writes the name into its
  //     argument.
  //   - ATTR_<Type>(x, ...): extra AscendString-based accessors for the string
  //     types; the other ATTR_<Type> macros expand to nothing.
  //   - get_attr_x(): initialises an Op<Type> from the default and passes it to
  //     Operator::GetAttr. Both branches return that same object, so callers
  //     cannot tell whether GetAttr failed. The default survives a failure only
  //     if GetAttr leaves its output untouched, which is not visible here.
  //   - set_attr_x(const Op<Type> &): stores the value via Operator::SetAttr.
  //   - set_attr_x(const function<Op<Type>()> &): ignores its argument and
  //     returns *this unchanged.
  // Finally it opens the private initializer `__attr_x()`. That function
  // registers the attribute and its default via Operator::AttrRegister and ends
  // with a bare `OpReg()` for the next link. Its local `attr_name` is unused.
  // NOTE(review): `static const void` return types trigger -Wignored-qualifiers.
  175. #define ATTR(x, Type, ...) \
  176. N(); \
  177. __attr_##x(); \
  178. } \
  179. \
  180. public: \
  181. ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \
  182. static const string name_attr_##x() { return #x; } \
  183. static const void name_attr_##x(AscendString &attr) { \
  184. attr = AscendString(#x); \
  185. } \
  186. ATTR_##Type(x, __VA_ARGS__) \
  187. Op##Type get_attr_##x() const { \
  188. Op##Type ret = __VA_ARGS__; \
  189. if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
  190. return ret; \
  191. } \
  192. return ret; \
  193. } \
  194. _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
  195. Operator::SetAttr(#x, v); \
  196. return *this; \
  197. } \
  198. _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
  199. \
  200. private: \
  201. void __attr_##x() { \
  202. Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \
  203. string attr_name(#x); \
  204. (void)OpReg()
  // REQUIRED_ATTR(x, Type): declares attribute `x` of type `Op<Type>` that has no
  // default. It is used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __required_attr_x(); }` and
  // emits the same public members as ATTR, with these differences:
  //   - REQUIRED_ATTR_<Type>(x) supplies the extra string-type accessors; these
  //     return GRAPH_FAILED when the attribute cannot be read.
  //   - get_attr_x() starts from a default-constructed Op<Type>. As in ATTR,
  //     both branches return the same object.
  //   - The set_attr_x(const function<Op<Type>()> &) overload is again a no-op.
  // The private initializer `__required_attr_x()` calls
  // Operator::RequiredAttrRegister and leaves a bare `OpReg()` for the next link.
  // Its local `attr_name` is unused.
  205. #define REQUIRED_ATTR(x, Type) \
  206. N(); \
  207. __required_attr_##x(); \
  208. } \
  209. \
  210. public: \
  211. ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \
  212. static const string name_attr_##x() { return #x; } \
  213. static const void name_attr_##x(AscendString &attr_name) { \
  214. attr_name = AscendString(#x); \
  215. } \
  216. REQUIRED_ATTR_##Type(x) \
  217. Op##Type get_attr_##x() const { \
  218. Op##Type ret; \
  219. if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
  220. return ret; \
  221. } \
  222. return ret; \
  223. } \
  224. _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
  225. Operator::SetAttr(#x, v); \
  226. return *this; \
  227. } \
  228. _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
  229. \
  230. private: \
  231. void __required_attr_##x() { \
  232. Operator::RequiredAttrRegister(#x); \
  233. string attr_name(#x); \
  234. (void)OpReg()
  // INPUT(x, t): declares input `x`. `t` (presumably the accepted tensor types)
  // is not referenced by this expansion. It is used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __input_x(); }`, so the
  // previous link calls this input's initializer. It then emits:
  //   - name_in_x(): the name as std::string (ATTRIBUTED_DEPRECATED) and
  //     name_in_x(AscendString &).
  //   - set_input_x(v), set_input_x(v, index) and set_input_x_by_name(v, srcName):
  //     connect operator `v` through Operator::SetInput and pass the extra
  //     argument along (presumably to select which output of `v` is used).
  //   - set_input_x(v, const string &srcName): ATTRIBUTED_DEPRECATED in favour of
  //     set_input_x_by_name.
  //   - get_input_desc_x() / update_input_desc_x(desc): read / update the input's
  //     TensorDesc by name.
  // The private initializer `__input_x()` calls Operator::InputRegister, as
  // opposed to OptionalInputRegister for OPTIONAL_INPUT. It leaves a bare
  // `OpReg()` for the next link.
  235. #define INPUT(x, t) \
  236. N(); \
  237. __input_##x(); \
  238. } \
  239. \
  240. public: \
  241. ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \
  242. static const string name_in_##x() { return #x; } \
  243. static const void name_in_##x(AscendString &name) { \
  244. name = AscendString(#x); \
  245. } \
  246. ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \
  247. _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
  248. Operator::SetInput(#x, v, srcName.c_str()); \
  249. return *this; \
  250. } \
  251. _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \
  252. Operator::SetInput(#x, v, srcName); \
  253. return *this; \
  254. } \
  255. _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \
  256. Operator::SetInput(#x, v, index); \
  257. return *this; \
  258. } \
  259. _THIS_TYPE &set_input_##x(Operator &v) { \
  260. Operator::SetInput(#x, v); \
  261. return *this; \
  262. } \
  263. TensorDesc get_input_desc_##x() const { return Operator::GetInputDescByName(#x); } \
  264. graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
  265. return Operator::UpdateInputDesc(#x, tensorDesc); \
  266. } \
  267. \
  268. private: \
  269. void __input_##x() { \
  270. Operator::InputRegister(#x); \
  271. (void)OpReg()
  // OPTIONAL_INPUT(x, t): declares optional input `x`. `t` is not referenced by
  // this expansion. It is used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __optional_input_x(); }`.
  // It emits the same public members as INPUT(x, t): the name_in_x overloads,
  // the set_input_x / set_input_x_by_name overloads, get_input_desc_x and
  // update_input_desc_x. The only difference is that its private initializer
  // `__optional_input_x()` registers the input with
  // Operator::OptionalInputRegister instead of InputRegister. It then leaves a
  // bare `OpReg()` for the next link.
  272. #define OPTIONAL_INPUT(x, t) \
  273. N(); \
  274. __optional_input_##x(); \
  275. } \
  276. \
  277. public: \
  278. ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \
  279. static const string name_in_##x() { return #x; } \
  280. static const void name_in_##x(AscendString &name) { \
  281. name = AscendString(#x); \
  282. } \
  283. _THIS_TYPE &set_input_##x(Operator &v) { \
  284. Operator::SetInput(#x, v); \
  285. return *this; \
  286. } \
  287. ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \
  288. _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
  289. Operator::SetInput(#x, v, srcName.c_str()); \
  290. return *this; \
  291. } \
  292. _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \
  293. Operator::SetInput(#x, v, srcName); \
  294. return *this; \
  295. } \
  296. _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \
  297. Operator::SetInput(#x, v, index); \
  298. return *this; \
  299. } \
  300. TensorDesc get_input_desc_##x() const { return Operator::GetInputDescByName(#x); } \
  301. graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
  302. return Operator::UpdateInputDesc(#x, tensorDesc); \
  303. } \
  304. \
  305. private: \
  306. void __optional_input_##x() { \
  307. Operator::OptionalInputRegister(#x); \
  308. (void)OpReg()
  // OUTPUT(x, t): declares output `x`. `t` is not referenced by this expansion.
  // It is used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __out_x(); }` and emits:
  //   - name_out_x(): the name as std::string (ATTRIBUTED_DEPRECATED) and
  //     name_out_x(AscendString &).
  //   - get_output_desc_x() / update_output_desc_x(desc): read / update the
  //     output's TensorDesc by name.
  // The private initializer `__out_x()` calls Operator::OutputRegister and leaves
  // a bare `OpReg()` for the next link.
  309. #define OUTPUT(x, t) \
  310. N(); \
  311. __out_##x(); \
  312. } \
  313. \
  314. public: \
  315. ATTRIBUTED_DEPRECATED(static const void name_out_##x(AscendString &)) \
  316. static const string name_out_##x() { return #x; } \
  317. static const void name_out_##x(AscendString &name) { \
  318. name = AscendString(#x); \
  319. } \
  320. TensorDesc get_output_desc_##x() const { return Operator::GetOutputDescByName(#x); } \
  321. graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \
  322. return Operator::UpdateOutputDesc(#x, tensorDesc); \
  323. } \
  324. \
  325. private: \
  326. void __out_##x() { \
  327. Operator::OutputRegister(#x); \
  328. (void)OpReg()
  // DYNAMIC_INPUT(x, t): declares dynamic input `x`, presumably one that takes a
  // variable number of tensors. `t` is not referenced by this expansion. It is
  // used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __dy_input_x(); }` and emits:
  //   - create_dynamic_input_x(num, isPushBack = true): forwards to
  //     Operator::DynamicInputRegister.
  //   - create_dynamic_input_byindex_x(num, index): forwards to
  //     Operator::DynamicInputRegisterByIndex.
  //   - get_dynamic_input_desc_x(index) / update_dynamic_input_desc_x(index, desc):
  //     read / update the TensorDesc at `index` of `x`.
  //   - set_dynamic_input_x(dstIndex, v [, srcName]): connects `v` at dstIndex via
  //     Operator::SetInput. The `const string &` overload is ATTRIBUTED_DEPRECATED
  //     in favour of the `const char *` one.
  // Unlike INPUT, it generates no name_in_x() accessors. The private initializer
  // `__dy_input_x()` registers `x` with a count of 0
  // (DynamicInputRegister(#x, 0, true)) and leaves a bare `OpReg()` for the next link.
  329. #define DYNAMIC_INPUT(x, t) \
  330. N(); \
  331. __dy_input_##x(); \
  332. } \
  333. \
  334. public: \
  335. _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \
  336. Operator::DynamicInputRegister(#x, num, isPushBack); \
  337. return *this; \
  338. } \
  339. _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \
  340. Operator::DynamicInputRegisterByIndex(#x, num, index); \
  341. return *this; \
  342. } \
  343. TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { \
  344. return Operator::GetDynamicInputDesc(#x, index); \
  345. } \
  346. graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
  347. return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \
  348. } \
  349. _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \
  350. Operator::SetInput(#x, dstIndex, v); \
  351. return *this; \
  352. } \
  353. ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_dynamic_input_##x(uint32_t, Operator &, const char *))\
  354. _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \
  355. Operator::SetInput(#x, dstIndex, v, srcName.c_str()); \
  356. return *this; \
  357. } \
  358. _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const char *srcName) { \
  359. Operator::SetInput(#x, dstIndex, v, srcName); \
  360. return *this; \
  361. } \
  362. \
  363. private: \
  364. void __dy_input_##x() { \
  365. Operator::DynamicInputRegister(#x, 0, true); \
  366. (void)OpReg()
  // DYNAMIC_OUTPUT(x, t): declares dynamic output `x`. `t` is not referenced by
  // this expansion. It is used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __dy_output_x(); }` and emits:
  //   - create_dynamic_output_x(num, isPushBack = true): forwards to
  //     Operator::DynamicOutputRegister.
  //   - get_dynamic_output_desc_x(index) / update_dynamic_output_desc_x(index, desc):
  //     read / update the TensorDesc at `index` of `x`.
  // Compared with OUTPUT and DYNAMIC_INPUT, it generates no name_out_x()
  // accessors and no by-index creator. The private initializer `__dy_output_x()`
  // registers `x` with a count of 0 and leaves a bare `OpReg()` for the next link.
  367. #define DYNAMIC_OUTPUT(x, t) \
  368. N(); \
  369. __dy_output_##x(); \
  370. } \
  371. \
  372. public: \
  373. _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \
  374. Operator::DynamicOutputRegister(#x, num, isPushBack); \
  375. return *this; \
  376. } \
  377. TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { \
  378. return Operator::GetDynamicOutputDesc(#x, index); \
  379. } \
  380. graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
  381. return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \
  382. } \
  383. \
  384. private: \
  385. void __dy_output_##x() { \
  386. Operator::DynamicOutputRegister(#x, 0, true); \
  387. (void)OpReg()
  // GRAPH(x): declares a single subgraph slot `x`. It is used as a link in the
  // REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __graph_x(); }` and emits:
  //   - name_graph_x(): the name as std::string (ATTRIBUTED_DEPRECATED) and
  //     name_graph_x(AscendString &).
  //   - get_subgraph_builder_x() / set_subgraph_builder_x(builder): read / set
  //     the SubgraphBuilder. The setter always uses index 0.
  //   - get_subgraph_x(): returns the Graph via Operator::GetSubgraph.
  // The private initializer `__graph_x()` calls SubgraphRegister(#x, false) and
  // SubgraphCountRegister(#x, 1), then leaves a bare `OpReg()` for the next link.
  // DYNAMIC_GRAPH instead passes `true` and leaves the count to
  // create_dynamic_subgraph_x.
  388. #define GRAPH(x) \
  389. N(); \
  390. __graph_##x(); \
  391. } \
  392. \
  393. public: \
  394. ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \
  395. static const string name_graph_##x() { return #x; } \
  396. static const void name_graph_##x(AscendString &name) { \
  397. name = AscendString(#x); \
  398. } \
  399. SubgraphBuilder get_subgraph_builder_##x() const { \
  400. return Operator::GetSubgraphBuilder(#x); \
  401. } \
  402. _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \
  403. Operator::SetSubgraphBuilder(#x, 0, v); \
  404. return *this; \
  405. } \
  406. Graph get_subgraph_##x() const { \
  407. return Operator::GetSubgraph(#x); \
  408. } \
  409. \
  410. private: \
  411. void __graph_##x() { \
  412. Operator::SubgraphRegister(#x, false); \
  413. Operator::SubgraphCountRegister(#x, 1); \
  414. (void)OpReg()
  // DYNAMIC_GRAPH(x): declares a subgraph slot `x` whose number of subgraphs is
  // set later by the caller. It is used as a link in the REG_OP chain.
  //
  // It finishes the previous initializer with `.N(); __graph_x(); }` and emits:
  //   - name_graph_x() (ATTRIBUTED_DEPRECATED) and name_graph_x(AscendString &).
  //   - create_dynamic_subgraph_x(num): sets the count via SubgraphCountRegister.
  //   - get_dynamic_subgraph_builder_x(index), get_dynamic_subgraph_x(index) and
  //     set_dynamic_subgraph_builder_x(index, builder).
  // The private initializer `__graph_x()` only calls SubgraphRegister(#x, true).
  // No count is registered until create_dynamic_subgraph_x() is called. The
  // initializer has the same name as the one GRAPH(x) emits, so one `x` cannot
  // be declared with both macros.
  415. #define DYNAMIC_GRAPH(x) \
  416. N(); \
  417. __graph_##x(); \
  418. } \
  419. \
  420. public: \
  421. ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \
  422. static const string name_graph_##x() { return #x; } \
  423. static const void name_graph_##x(AscendString &name) { \
  424. name = AscendString(#x); \
  425. } \
  426. _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \
  427. Operator::SubgraphCountRegister(#x, num); \
  428. return *this; \
  429. } \
  430. SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \
  431. return Operator::GetDynamicSubgraphBuilder(#x, index); \
  432. } \
  433. Graph get_dynamic_subgraph_##x(uint32_t index) const { \
  434. return Operator::GetDynamicSubgraph(#x, index); \
  435. } \
  436. _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index,const SubgraphBuilder &v) { \
  437. Operator::SetSubgraphBuilder(#x, index, v); \
  438. return *this; \
  439. } \
  440. \
  441. private: \
  442. void __graph_##x() { \
  443. Operator::SubgraphRegister(#x, true); \
  444. (void)OpReg()
  // PASTE(a, b) concatenates two tokens. __OP_END_IMPL__ hands `y` to PASTE
  // rather than applying `##` to it directly, so a `y` of __COUNTER__ is
  // macro-expanded to a number before the paste happens.
  //
  // __OP_END_IMPL__(x, y), reached through OP_END_FACTORY_REG(x), ends the REG_OP
  // chain. It:
  //   1. finishes the last open initializer (`.N(); }`);
  //   2. checks with a static_assert that `x` is the class REG_OP opened
  //      (`_THIS_TYPE`);
  //   3. closes that class;
  //   4. still inside namespace op, defines a static OperatorCreatorRegister named
  //      g_register<y>, which registers `#x` with a lambda that constructs `x`
  //      from an AscendString name;
  //   5. closes namespace op.
  // OP_END_FACTORY_REG(x) supplies __COUNTER__ as `y`, so every registrar object
  // in a translation unit gets a distinct name.
  445. #define PASTE(g_register, y) g_register##y
  446. #define __OP_END_IMPL__(x, y) \
  447. N(); \
  448. } \
  449. static_assert( \
  450. std::is_same<x, _THIS_TYPE>::value, \
  451. "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
  452. } \
  453. ; \
  454. static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const AscendString &name) { return x(name); }); \
  455. }
  456. #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__)
  // Specialized shape inferencer macro: declares `func_name` as the
  // shape-inference function of operator class op::op_name.
  // GE_FUNC_DEV_VISIBILITY and GE_FUNC_HOST_VISIBILITY are defined outside this
  // header (presumably symbol-visibility attributes).
  #define IMPLEMT_INFERFUNC(op_name, func_name) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)
  // Same as IMPLEMT_INFERFUNC, for a function that works on the untyped Operator.
  #define IMPLEMT_COMMON_INFERFUNC(func_name) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op)
  // Declares `func_name` as the format-inference function of op::op_name.
  #define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)
  // Specialized verifier macro. Note that the verifier receives op::op_name by value.
  #define IMPLEMT_VERIFIER(op_name, func_name) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op)

  // Adapters that wrap function `x` into a lambda taking `Operator &`.
  // The lambdas capture nothing:
  //   - The *_FUNC_REG macros below normally expand at namespace scope, inside
  //     the initializer of a static registrar. A lambda outside block scope may
  //     not have a capture-default, and clang rejects `[&]` there.
  //   - `x` names a function, so nothing needs to be captured.
  //   - Captureless lambdas are also convertible to plain function pointers.
  // The typed adapters downcast with static_cast, because REG_OP derives
  // op::op_name publicly from Operator.
  // NOTE(review): if `x` is a local (non-static) callable object used inside a
  // function body, it must now be made static.
  #define INFER_VERIFY_FUNC(op_name, x) [](Operator &v) { return x(static_cast<op::op_name &>(v)); }
  #define COMMON_INFER_VERIFY_FUNC(x) [](Operator &v) { return x(v); }
  #define INFER_FORMAT_FUNC(op_name, x) [](Operator &v) { return x(static_cast<op::op_name &>(v)); }
  // Registrar definitions. `n` makes the name of each static registrar unique.
  #define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x)
  #define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x)
  // Infer format func register
  #define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \
    static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x)
  // Shape inferencer & verifier register macros. __COUNTER__ supplies the unique
  // suffix; it is passed through an *_IMPL__ macro so it is expanded before PASTE's `##`.
  #define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)
  #define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__)
  #define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)
  // Infer format func reg
  #define INFER_FORMAT_FUNC_REG(op_name, x) \
    __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__)
  // Common shape inferencer. ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name)
  // yields a lambda that:
  //   - copies the dims of input `in_name` into both the shape and the origin
  //     shape of output `out_name`;
  //   - copies the input's data type to that output;
  //   - returns the status of Operator::UpdateOutputDesc.
  // The lambda takes its Operator by value.
  #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \
    [](Operator op) -> graphStatus { \
      auto in_dims = op.GetInputDescByName(in_name).GetShape().GetDims(); \
      auto in_dtype = op.GetInputDescByName(in_name).GetDataType(); \
      TensorDesc out_desc = op.GetOutputDescByName(out_name); \
      out_desc.SetShape(ge::Shape(in_dims)); \
      out_desc.SetOriginShape(ge::Shape(in_dims)); \
      out_desc.SetDataType(in_dtype); \
      return op.UpdateOutputDesc(out_name, out_desc); \
    }
  493. graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
  494. const function<vector<int64_t>()> &get_in2_shape,
  495. const function<void(const vector<int64_t> &y_shape)> &set_out_shape);
  496. #define BROADCAST_INFER(in1_name, in2_name, out_name) \
  497. [](Operator op) -> graphStatus { \
  498. return BroadCastInfer([&]() { return op.GetInputDescByName(in1_name).GetShape().GetDims(); }, \
  499. [&]() { return op.GetInputDescByName(in2_name).GetShape().GetDims(); }, \
  500. [&](const vector<int64_t> &y_shape) { \
  501. TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \
  502. op_output_desc.SetShape(ge::Shape(y_shape)); \
  503. (void)op.UpdateOutputDesc(out_name, op_output_desc);}); \
  504. }
  505. } // namespace ge
  506. #endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：