You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

ge_attr_value.h 11 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_GE_ATTR_VALUE_H_
  17. #define INC_GRAPH_GE_ATTR_VALUE_H_
  18. #include <iostream>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. #include "graph/buffer.h"
  25. #include "detail/attributes_holder.h"
  26. #include "graph/ge_error_codes.h"
  27. #include "graph/ge_tensor.h"
  28. using std::map;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  // Forward declarations of the graph/tensor types referenced by the attribute API below.
  class ComputeGraph;
  class GeAttrValueImp;
  class GeTensor;
  class GeTensorDesc;

  // Shared-ownership handles to tensors, in mutable and read-only flavours.
  using GeTensorPtr = std::shared_ptr<GeTensor>;
  using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;

  // Shared-ownership handles to compute graphs, in mutable and read-only flavours.
  using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;
  using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>;
  40. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
  41. public:
  42. class NamedAttrs : public AttrHolder {
  43. public:
  44. NamedAttrs();
  45. virtual ~NamedAttrs() = default;
  46. void SetName(const std::string &name);
  47. string GetName() const;
  48. GeAttrValue GetItem(const string &key) const;
  49. protected:
  50. ProtoAttrMapHelper MutableAttrMap() override;
  51. ConstProtoAttrMapHelper GetAttrMap() const override;
  52. private:
  53. // Create namedAttrs from protobuf obj
  54. NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg);
  55. GeIrProtoHelper<proto::NamedAttrs> named_attrs_;
  56. friend class GeAttrValueImp;
  57. };
  58. using INT = int64_t;
  59. using FLOAT = float;
  60. using BOOL = bool;
  61. using STR = std::string;
  62. using TENSOR = GeTensorPtr;
  63. using TENSOR_DESC = GeTensorDesc;
  64. using GRAPH = ComputeGraphPtr;
  65. using BYTES = Buffer;
  66. using NAMED_ATTRS = NamedAttrs;
  67. using DATA_TYPE = ge::DataType;
  68. using LIST_INT = vector<INT>;
  69. using LIST_FLOAT = vector<FLOAT>;
  70. using LIST_BOOL = vector<BOOL>;
  71. using LIST_STR = vector<STR>;
  72. using LIST_TENSOR = vector<TENSOR>;
  73. using LIST_TENSOR_DESC = vector<TENSOR_DESC>;
  74. using LIST_GRAPH = vector<GRAPH>;
  75. using LIST_BYTES = vector<BYTES>;
  76. using LIST_NAMED_ATTRS = vector<NAMED_ATTRS>;
  77. using LIST_LIST_INT = vector<vector<int64_t>>;
  78. using LIST_DATA_TYPE = vector<ge::DataType>;
  79. enum ValueType {
  80. VT_NONE = 0,
  81. VT_STRING,
  82. VT_FLOAT,
  83. VT_BOOL,
  84. VT_INT,
  85. VT_TENSOR_DESC,
  86. VT_TENSOR,
  87. VT_BYTES,
  88. VT_GRAPH,
  89. VT_NAMED_ATTRS,
  90. VT_LIST_LIST_INT,
  91. VT_DATA_TYPE,
  92. VT_LIST_BASE = 1000,
  93. VT_LIST_STRING = VT_LIST_BASE + VT_STRING,
  94. VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT,
  95. VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL,
  96. VT_LIST_INT = VT_LIST_BASE + VT_INT,
  97. VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC,
  98. VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR,
  99. VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES,
  100. VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH,
  101. VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS,
  102. VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE,
  103. };
  104. template <class T>
  105. struct IsAttrTypeEnable {
  106. using DT = typename std::remove_cv<T>::type;
  107. static bool const VALUE = std::is_same<INT, DT>::value || std::is_same<FLOAT, DT>::value ||
  108. std::is_same<BOOL, DT>::value || std::is_same<STR, DT>::value ||
  109. std::is_same<GRAPH, DT>::value || std::is_same<TENSOR, DT>::value ||
  110. std::is_same<TENSOR_DESC, DT>::value || std::is_same<BYTES, DT>::value ||
  111. std::is_same<NAMED_ATTRS, DT>::value || std::is_same<DATA_TYPE, DT>::value;
  112. // Not has list type of NamedAttrs
  113. static bool const LIST_VALUE = std::is_same<LIST_INT, DT>::value || std::is_same<LIST_FLOAT, DT>::value ||
  114. std::is_same<LIST_BOOL, DT>::value || std::is_same<LIST_STR, DT>::value ||
  115. std::is_same<LIST_GRAPH, DT>::value || std::is_same<LIST_TENSOR, DT>::value ||
  116. std::is_same<LIST_TENSOR_DESC, DT>::value || std::is_same<LIST_BYTES, DT>::value ||
  117. std::is_same<LIST_NAMED_ATTRS, DT>::value ||
  118. std::is_same<LIST_LIST_INT, DT>::value || std::is_same<LIST_DATA_TYPE, DT>::value;
  119. };
  120. template <typename vector_type>
  121. // To cols
  122. using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type;
  123. template <typename one_type>
  124. using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type;
  125. template <typename val_type>
  126. using enable_if_type_valid_t =
  127. typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type;
  128. template <typename seriliable_type>
  129. using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable;
  130. GeAttrValue();
  131. ~GeAttrValue() = default;
  132. // SetValue, Set initializer_list
  133. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  134. graphStatus SetValue(std::initializer_list<DT> &&val) {
  135. T vectorVal;
  136. for (auto &item : val) {
  137. vectorVal.push_back(item);
  138. }
  139. return SetValue(vectorVal);
  140. }
  141. // SetValue, Set vector
  142. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  143. graphStatus SetValue(const std::vector<DT> &val) {
  144. T vectorVal;
  145. for (auto item : val) {
  146. vectorVal.push_back(item);
  147. }
  148. return SetValue(vectorVal);
  149. }
  150. // SetValue, not list type
  151. template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0>
  152. graphStatus SetValue(DT &&val) {
  153. return SetValue(T(std::forward<DT>(val)));
  154. }
  155. // GE_SERIALIZABLE
  156. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  157. graphStatus SetValue(const T &t) {
  158. return t.Save(*this);
  159. }
  160. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  161. graphStatus SetValue(const vector<T> &t) {
  162. vector<NamedAttrs> attrs;
  163. for (auto &item : t) {
  164. GeAttrValue val;
  165. item.Save(val);
  166. NamedAttrs attrsItem;
  167. (void)val.GetValue<NamedAttrs>(attrsItem);
  168. attrs.push_back(attrsItem);
  169. }
  170. return SetValue(attrs);
  171. }
  172. // GetValue, list value
  173. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0,
  174. typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
  175. graphStatus GetValue(std::vector<DT> &val) const {
  176. T valGet;
  177. val.clear();
  178. auto status = GetValue(valGet);
  179. if (status != GRAPH_SUCCESS) {
  180. return status;
  181. }
  182. for (auto item : valGet) {
  183. val.push_back(item);
  184. }
  185. return GRAPH_SUCCESS;
  186. }
  187. // GetValue, not list type
  188. template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0,
  189. typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
  190. graphStatus GetValue(DT &val) const {
  191. T valGet;
  192. auto status = GetValue(valGet);
  193. if (status != GRAPH_SUCCESS) {
  194. return status;
  195. }
  196. val = DT(valGet);
  197. return GRAPH_SUCCESS;
  198. }
  199. // GE_SERIALIZABLE
  200. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  201. graphStatus GetValue(T &t) {
  202. return t.Load(*this);
  203. }
  204. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  205. graphStatus GetValue(vector<T> &t) {
  206. graphStatus status;
  207. t.clear();
  208. vector<NamedAttrs> attrs;
  209. status = this->GetValue(attrs);
  210. if (status != GRAPH_SUCCESS) {
  211. return status;
  212. }
  213. for (auto &attr : attrs) {
  214. T item;
  215. GeAttrValue val;
  216. (void)val.SetValue(attr);
  217. status = item.Load(val);
  218. if (status != GRAPH_SUCCESS) {
  219. return status;
  220. }
  221. t.push_back(item);
  222. }
  223. return GRAPH_SUCCESS;
  224. }
  225. template <typename T, typename DT, enable_if_type_valid_t<T> = 0>
  226. static GeAttrValue CreateFrom(DT &&val) {
  227. GeAttrValue valRet;
  228. (void)valRet.SetValue<T>(std::forward<DT>(val));
  229. return valRet;
  230. }
  231. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  232. static GeAttrValue CreateFrom(std::initializer_list<DT> &&val) {
  233. GeAttrValue valRet;
  234. (void)valRet.SetValue<T>(std::move(val));
  235. return valRet;
  236. }
  237. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  238. static GeAttrValue CreateFrom(const T &val) {
  239. GeAttrValue valRet;
  240. (void)valRet.SetValue(val);
  241. return valRet;
  242. }
  243. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  244. static GeAttrValue CreateFrom(const vector<T> &val) {
  245. GeAttrValue valRet;
  246. (void)valRet.SetValue(val);
  247. return valRet;
  248. }
  249. ValueType GetValueType() const;
  250. bool IsEmpty() const;
  251. GeAttrValue Copy() const;
  252. // For map key
  253. bool operator==(const GeAttrValue &other) const { return value_ == other.value_; }
  254. graphStatus MutableTensor(GeTensorPtr &tensor);
  255. graphStatus MutableListTensor(vector<GeTensorPtr> &list_tensor);
  256. private:
  257. #define VALUE_SET_GET_DEC(DT) \
  258. graphStatus SetValue(const DT &val); \
  259. graphStatus GetValue(DT &val) const;
  260. VALUE_SET_GET_DEC(GeAttrValue::STR)
  261. VALUE_SET_GET_DEC(GeAttrValue::INT)
  262. VALUE_SET_GET_DEC(GeAttrValue::FLOAT)
  263. VALUE_SET_GET_DEC(GeAttrValue::BOOL)
  264. VALUE_SET_GET_DEC(GeTensorDesc)
  265. VALUE_SET_GET_DEC(GeAttrValue::TENSOR)
  266. VALUE_SET_GET_DEC(GeAttrValue::GRAPH)
  267. VALUE_SET_GET_DEC(BYTES)
  268. VALUE_SET_GET_DEC(NamedAttrs)
  269. VALUE_SET_GET_DEC(ge::DataType)
  270. VALUE_SET_GET_DEC(vector<GeAttrValue::STR>)
  271. VALUE_SET_GET_DEC(vector<GeAttrValue::INT>)
  272. VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>)
  273. VALUE_SET_GET_DEC(vector<GeAttrValue::BOOL>)
  274. VALUE_SET_GET_DEC(vector<GeTensorDesc>)
  275. VALUE_SET_GET_DEC(vector<GeAttrValue::TENSOR>)
  276. VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>)
  277. VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>)
  278. VALUE_SET_GET_DEC(vector<NamedAttrs>)
  279. VALUE_SET_GET_DEC(vector<vector<int64_t>>)
  280. VALUE_SET_GET_DEC(vector<ge::DataType>)
  281. #undef VALUE_SET_GET_DEC
  282. GeIrProtoHelper<proto::AttrDef> value_;
  283. GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val);
  284. friend class AttrHolder;
  285. friend class ModelSerializeImp;
  286. friend class OnnxUtils;
  287. };
  288. class AttrValueImpl {
  289. public:
  290. AttrValueImpl() = default;
  291. ~AttrValueImpl() = default;
  292. GeAttrValue geAttrValue_;
  293. };
  294. } // namespace ge
  295. #endif // INC_GRAPH_GE_ATTR_VALUE_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：