You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

lookup_ops.h 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_OP_LOOKUP_OPS_H_
  17. #define GE_OP_LOOKUP_OPS_H_
  18. #include "graph/operator_reg.h"
  19. namespace ge {
  20. /**
  21. *@brief Replaces the contents of the table with the specified keys and values.
  22. *@par Inputs:
  23. *The dtype of input handle must be resource. Inputs include: \n
  24. *@li handle: A Tensor of type resource. Handle to the table.
  25. *@li keys: A Tensor of type string, int32 or int64. Any shape. Keys to write into the table.
  26. *@li values: A Tensor of type bool, double, float, int32, int64 or string. Values to associate with keys.
  27. *@par Third-party framework compatibility.
  28. *Compatible with TensorFlow LookupTableImport operator.
  29. */
  30. REG_OP(LookupTableImport)
  31. .INPUT(handle, TensorType({DT_RESOURCE}))
  32. .INPUT(keys, TensorType({DT_STRING, DT_INT32, DT_INT64}))
  33. .INPUT(values, TensorType({DT_BOOL, DT_DOUBLE, \
  34. DT_FLOAT, DT_INT32, DT_INT64, DT_STRING}))
  35. .OP_END_FACTORY_REG(LookupTableImport)
  36. /**
  37. *@brief Updates the table to associate keys with values.
  38. *@par Inputs:
  39. *The dtype of input handle must be resource. Inputs include: \n
  40. *@li handle: A Tensor of type resource. Handle to the table.
  41. *@li keys: A Tensor of type string, int32 or int64. Any shape. Keys to insert into the table.
  42. *@li values: A Tensor of type bool, double, float, int32, int64 or string. Values to associate with keys.
  43. *@attention Constraints: \n
  44. *@li The tensor keys must be of the same type as the keys of the table. \n
  45. *@li The tensor values must be of the same type as the values of the table. \n
  46. *@par Third-party framework compatibility.
  47. *Compatible with TensorFlow LookupTableInsert operator.
  48. */
  49. REG_OP(LookupTableInsert)
  50. .INPUT(handle, TensorType({DT_RESOURCE}))
  51. .INPUT(keys, TensorType({DT_STRING, DT_INT32, DT_INT64}))
  52. .INPUT(values, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, \
  53. DT_INT32, DT_INT64, DT_STRING}))
  54. .OP_END_FACTORY_REG(LookupTableInsert)
  55. /**
  56. *@brief Outputs all keys and values in the table.
  57. *@par Inputs:
  58. *The dtype of input handle must be resource. Inputs include: \n
  59. *handle: A Tensor of type resource. Handle to the table.
  60. *@par Attributes:
  61. *@li Tkeys: A required DType. Type of the output keys.
  62. *@li Tvalues: A required DType. Type of the output values.
  63. *@par Outputs:
  64. *@li keys: A Tensor of type Tkeys (one of int32, int64 or string).
  65. *@li values: A Tensor of type Tvalues (one of bool, double, float, int32, int64 or string).
  66. *@par Third-party framework compatibility.
  67. *Compatible with TensorFlow LookupTableExport operator.
  68. */
  69. REG_OP(LookupTableExport)
  70. .INPUT(handle, TensorType({DT_RESOURCE}))
  71. .OUTPUT(keys, TensorType({DT_INT32, DT_INT64, DT_STRING}))
  72. .OUTPUT(values, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, \
  73. DT_INT32, DT_INT64, DT_STRING}))
  74. .REQUIRED_ATTR(Tkeys, Type)
  75. .REQUIRED_ATTR(Tvalues, Type)
  76. .OP_END_FACTORY_REG(LookupTableExport)
  77. /**
  78. *@brief Computes the number of elements in the given table.
  79. *@par Inputs:
  80. *The dtype of input handle must be resource. Inputs include: \n
  81. *handle: A Tensor of type resource. Handle to the table.
  82. *@par Outputs:
  83. *size: A Tensor of type int64. The number of elements in the table.
  84. *@par Third-party framework compatibility.
  85. *Compatible with TensorFlow LookupTableSize operator.
  86. */
  87. REG_OP(LookupTableSize)
  88. .INPUT(handle, TensorType({DT_RESOURCE}))
  89. .OUTPUT(size, TensorType({DT_INT64}))
  90. .OP_END_FACTORY_REG(LookupTableSize)
  91. /**
  92. *@brief Looks up keys in a table and outputs the corresponding values.
  93. *@par Inputs:
  94. *The dtype of input handle must be resource. Inputs include: \n
  95. *@li handle: A Tensor of type resource. Handle to the table.
  96. *@li keys: A Tensor of type int32, int64 or string. Any shape. Keys to look up.
  97. *@li default_value: A Tensor of type double, float, int32, int64, string or bool. In the compatible TensorFlow op, this is the value output for keys not present in the table.
  98. *@par Attributes:
  99. *Tout: A required DType. Specified type of output values.
  100. *@par Outputs:
  101. *values: A Tensor. Has the same type as default_value.
  102. *@par Third-party framework compatibility.
  103. *Compatible with TensorFlow LookupTableFind operator.
  104. */
  105. REG_OP(LookupTableFind)
  106. .INPUT(handle, TensorType({DT_RESOURCE}))
  107. .INPUT(keys, TensorType({DT_INT32, DT_INT64, DT_STRING}))
  108. .INPUT(default_value, TensorType({DT_DOUBLE, DT_FLOAT, \
  109. DT_INT32, DT_INT64, DT_STRING, DT_BOOL}))
  110. .OUTPUT(values, TensorType({DT_DOUBLE, DT_FLOAT, DT_INT32, \
  111. DT_INT64, DT_STRING, DT_BOOL}))
  112. .REQUIRED_ATTR(Tout, Type)
  113. .OP_END_FACTORY_REG(LookupTableFind)
  114. /**
  115. *@brief Creates a non-initialized hash table.
  116. *@par Attributes:
  117. *@li container: An optional string. Defaults to "". If non-empty, this table \n
  118. is placed in the given container. Otherwise, a default container is used.
  119. *@li shared_name: An optional string. Defaults to "". If non-empty, this \n
  120. table is shared under the given name across multiple sessions.
  121. *@li use_node_name_sharing: An optional bool. Defaults to false. If true and \n
  122. shared_name is empty, the table is shared using the node name.
  123. *@li key_dtype: A required DType. Type of the table keys.
  124. *@li value_dtype: A required DType. Type of the table values.
  125. *@par Outputs:
  126. *handle: A Tensor of type resource. Handle to the table.
  127. *@attention Constraints: \n
  128. *The implementation of HashTable on Ascend runs on AI CPU and has poor performance. \n
  129. *@par Third-party framework compatibility.
  130. *Compatible with TensorFlow HashTable operator.
  131. */
  132. REG_OP(HashTable)
  133. .OUTPUT(handle, TensorType({DT_RESOURCE}))
  134. .ATTR(container, String, "")
  135. .ATTR(shared_name, String, "")
  136. .ATTR(use_node_name_sharing, Bool, false)
  137. .REQUIRED_ATTR(key_dtype, Type)
  138. .REQUIRED_ATTR(value_dtype, Type)
  139. .OP_END_FACTORY_REG(HashTable)
  140. /**
  141. *@brief Table initializer that takes two tensors for keys and values \n
  142. respectively.
  143. *@par Inputs:
  144. *The dtype of input handle must be resource. Inputs include: \n
  145. *@li handle: A Tensor of type resource. Handle to a table which will be \n
  146. initialized.
  147. *@li keys: A Tensor of type int32, int64 or string. Keys of the table (attr Tkey in TensorFlow; no such attr is registered here).
  148. *@li values: A Tensor of type int32, int64, float, double, bool or string. Values of the table (attr Tval in TensorFlow; no such attr is registered here).
  149. *@par Third-party framework compatibility.
  150. *Compatible with TensorFlow InitializeTable operator.
  151. */
  152. REG_OP(InitializeTable)
  153. .INPUT(handle, TensorType({DT_RESOURCE}))
  154. .INPUT(keys, TensorType({DT_INT32, DT_INT64, DT_STRING}))
  155. .INPUT(values, TensorType({DT_INT32, DT_INT64, DT_FLOAT, \
  156. DT_DOUBLE, DT_BOOL, DT_STRING}))
  157. .OP_END_FACTORY_REG(InitializeTable)
  158. /**
  159. *@brief Creates an empty hash table that uses tensors as the backing store.
  160. *@par Inputs:
  161. *The input deleted_key must have the same type as empty_key. Inputs include: \n
  162. *@li empty_key: A Tensor of type int32, int64 or string. The key used to represent empty key buckets \n
  163. internally. Must not be used in insert or lookup operations.
  164. *@li deleted_key: A Tensor. Must have the same type as empty_key. NOTE(review): no key_dtype attribute is registered for this op; the key type is presumably taken from empty_key - confirm.
  165. *@par Attributes:
  166. *@li container: An optional string. Defaults to "". If non-empty, this table \n
  167. is placed in the given container. Otherwise, a default container is used.
  168. *@li shared_name: An optional string. Defaults to "". If non-empty, this \n
  169. table is shared under the given name across multiple sessions.
  170. *@li use_node_name_sharing: An optional bool. Defaults to false. If true and \n
  171. shared_name is empty, the table is shared using the node name.
  172. *@li value_dtype: A required DType. Type of the table values.
  173. *@li value_shape: An optional list of ints. Defaults to []. \n
  174. The shape of each value.
  175. *@li initial_num_buckets: An optional int. Defaults to 131072. The initial \n
  176. number of hash table buckets. Must be a power of 2.
  177. *@li max_load_factor: An optional float. Defaults to 0.8. The maximum ratio \n
  178. between number of entries and number of buckets before growing the table. \n
  179. Must be between 0 and 1.
  180. *@par Outputs:
  181. *handle: A Tensor of type resource. Handle to the table.
  182. *@par Third-party framework compatibility.
  183. *Compatible with TensorFlow MutableDenseHashTable operator.
  184. */
  185. REG_OP(MutableDenseHashTable)
  186. .INPUT(empty_key, TensorType({DT_INT32, DT_INT64, DT_STRING}))
  187. .INPUT(deleted_key, TensorType({DT_INT32, DT_INT64, DT_STRING}))
  188. .OUTPUT(handle, TensorType({DT_RESOURCE}))
  189. .ATTR(container, String, "")
  190. .ATTR(shared_name, String, "")
  191. .ATTR(use_node_name_sharing, Bool, false)
  192. .REQUIRED_ATTR(value_dtype, Type)
  193. .ATTR(value_shape, ListInt, {})
  194. .ATTR(initial_num_buckets, Int, 131072)
  195. .ATTR(max_load_factor, Float, 0.8)
  196. .OP_END_FACTORY_REG(MutableDenseHashTable)
  197. /**
  198. *@brief Creates an empty hash table whose values are tensors of shape value_shape.
  199. *@par Attributes:
  200. *@li container: An optional string. Defaults to "". If non-empty, this table \n
  201. is placed in the given container. Otherwise, a default container is used.
  202. *@li shared_name: An optional string. Defaults to "". If non-empty, this \n
  203. table is shared under the given name across multiple sessions.
  204. *@li use_node_name_sharing: An optional bool. Defaults to false. If true and \n
  205. shared_name is empty, the table is shared using the node name.
  206. *@li key_dtype: A required DType. Type of the table keys.
  207. *@li value_dtype: A required DType. Type of the table values.
  208. *@li value_shape: An optional list of ints. Defaults to []. The shape of each value.
  209. *@par Outputs:
  210. *handle: A Tensor of type resource. Handle to the table.
  211. *@par Third-party framework compatibility.
  212. *Compatible with TensorFlow MutableHashTableOfTensors operator.
  213. */
  214. REG_OP(MutableHashTableOfTensors)
  215. .OUTPUT(handle, TensorType({DT_RESOURCE}))
  216. .ATTR(container, String, "")
  217. .ATTR(shared_name, String, "")
  218. .ATTR(use_node_name_sharing, Bool, false)
  219. .REQUIRED_ATTR(key_dtype, Type)
  220. .REQUIRED_ATTR(value_dtype, Type)
  221. .ATTR(value_shape, ListInt, {})
  222. .OP_END_FACTORY_REG(MutableHashTableOfTensors)
  223. /**
  224. *@brief Creates an empty mutable hash table.
  225. *@par Attributes:
  226. *@li container: An optional string. Defaults to "". If non-empty, this table \n
  227. is placed in the given container. Otherwise, a default container is used.
  228. *@li shared_name: An optional string. Defaults to "". If non-empty, this \n
  229. table is shared under the given name across multiple sessions.
  230. *@li use_node_name_sharing: An optional bool. Defaults to false. If true and \n
  231. shared_name is empty, the table is shared using the node name.
  232. *@li key_dtype: A required DType. Type of the table keys.
  233. *@li value_dtype: A required DType. Type of the table values.
  234. *@par Outputs:
  235. *handle: A Tensor of type resource. Handle to the table.
  236. *@par Third-party framework compatibility.
  237. *Compatible with TensorFlow MutableHashTable operator.
  238. */
  239. REG_OP(MutableHashTable)
  240. .OUTPUT(handle, TensorType({DT_RESOURCE}))
  241. .ATTR(container, String, "")
  242. .ATTR(shared_name, String, "")
  243. .ATTR(use_node_name_sharing, Bool, false)
  244. .REQUIRED_ATTR(key_dtype, Type)
  245. .REQUIRED_ATTR(value_dtype, Type)
  246. .OP_END_FACTORY_REG(MutableHashTable)
  247. } // namespace ge
  248. #endif // GE_OP_LOOKUP_OPS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承上启下的作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：