You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

ge_attr_value.cc 57 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/ge_attr_value.h"
  17. #include "graph/ge_tensor.h"
  18. #include "external/graph/graph.h"
  19. #include "utils/attr_utils.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/model_serialize.h"
  22. #include "proto/ge_ir.pb.h"
  23. #include "detail/model_serialize_imp.h"
  24. #include "debug/ge_attr_define.h"
  25. #include "debug/ge_log.h"
  26. #include "debug/ge_util.h"
  27. using std::map;
  28. using std::string;
  29. using std::vector;
  30. namespace ge {
  31. NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); }
  32. NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg)
  33. : named_attrs_(owner, proto_msg) {} // lint !e1744
  34. void NamedAttrs::SetName(const std::string &name) {
  35. auto proto_msg = named_attrs_.GetProtoMsg();
  36. if (proto_msg != nullptr) {
  37. proto_msg->set_name(name);
  38. }
  39. }
  40. string NamedAttrs::GetName() const {
  41. auto proto_msg = named_attrs_.GetProtoMsg();
  42. if (proto_msg != nullptr) {
  43. return proto_msg->name();
  44. }
  45. return string();
  46. }
  47. GeAttrValue NamedAttrs::GetItem(const string &key) const {
  48. GeAttrValue value;
  49. (void)GetAttr(key, value);
  50. return value;
  51. }
  52. ProtoAttrMapHelper NamedAttrs::MutableAttrMap() {
  53. auto proto_msg = named_attrs_.GetProtoMsg();
  54. if (proto_msg != nullptr) {
  55. return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr());
  56. }
  57. return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr);
  58. }
  59. ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const {
  60. auto proto_msg = named_attrs_.GetProtoMsg();
  61. if (proto_msg != nullptr) {
  62. return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr());
  63. }
  64. return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr);
  65. }
  66. class GeAttrValueImp {
  67. public:
  68. static map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> attr_val_one_type_map_;
  69. static map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> attr_val_list_type_map_;
  70. static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val);
  71. static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val);
  72. static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val);
  73. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val);
  74. static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val);
  75. static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val);
  76. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val);
  77. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val);
  78. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val);
  79. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val);
  80. static bool SetValue(proto::AttrDef &attr_def, const vector<int64_t> &val);
  81. static bool SetValue(proto::AttrDef &attr_def, const vector<int32_t> &val);
  82. static bool SetValue(proto::AttrDef &attr_def, const vector<uint32_t> &val);
  83. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val);
  84. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val);
  85. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val);
  86. static bool SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value);
  87. static bool SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value);
  88. static bool SetValue(proto::AttrDef &attr_def, const vector<GeTensor> &val);
  89. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val);
  90. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val);
  91. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val);
  92. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val);
  93. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val);
  94. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val);
  95. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val);
  96. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val);
  97. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val);
  98. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val);
  99. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  100. GeAttrValue::TENSOR_DESC &val);
  101. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val);
  102. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  103. GeAttrValue::NAMED_ATTRS &val);
  104. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val);
  105. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  106. GeAttrValue::LIST_INT &val);
  107. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  108. GeAttrValue::LIST_FLOAT &val);
  109. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  110. GeAttrValue::LIST_BOOL &val);
  111. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  112. GeAttrValue::LIST_STR &val);
  113. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  114. GeAttrValue::LIST_TENSOR &val);
  115. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector<GeTensor> &val);
  116. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  117. GeAttrValue::LIST_TENSOR_DESC &val);
  118. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  119. GeAttrValue::LIST_BYTES &val);
  120. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  121. GeAttrValue::LIST_NAMED_ATTRS &val);
  122. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  123. GeAttrValue::LIST_GRAPH &val);
  124. // Value will be moved
  125. static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer);
  126. static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer);
  127. // Value will be moved
  128. static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  129. vector<Buffer> &list_buffer);
  130. static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  131. vector<Buffer> &list_buffer);
  132. static bool SetValue(proto::AttrDef &attr_def, const vector<vector<int64_t>> &value);
  133. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  134. vector<vector<int64_t>> &value);
  135. static bool SetValue(proto::AttrDef &attr_def, const vector<ge::DataType> &value);
  136. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  137. vector<ge::DataType> &value);
  138. static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value);
  139. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value);
  140. };
  141. map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> GeAttrValueImp::attr_val_one_type_map_ = {
  142. {proto::AttrDef::kI, GeAttrValue::VT_INT},
  143. {proto::AttrDef::kF, GeAttrValue::VT_FLOAT},
  144. {proto::AttrDef::kB, GeAttrValue::VT_BOOL},
  145. {proto::AttrDef::kS, GeAttrValue::VT_STRING},
  146. {proto::AttrDef::kT, GeAttrValue::VT_TENSOR},
  147. {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC},
  148. {proto::AttrDef::kG, GeAttrValue::VT_GRAPH},
  149. {proto::AttrDef::kBt, GeAttrValue::VT_BYTES},
  150. {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS},
  151. {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT},
  152. {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE},
  153. };
  154. map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> GeAttrValueImp::attr_val_list_type_map_ = {
  155. {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT},
  156. {proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT},
  157. {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL},
  158. {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING},
  159. {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR},
  160. {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC},
  161. {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH},
  162. {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES},
  163. {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS},
  164. {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE},
  165. };
  166. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); }
  167. GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {}
  168. GeAttrValue::ValueType GeAttrValue::GetValueType() const {
  169. auto proto_msg = value_.GetProtoMsg();
  170. if (proto_msg != nullptr) {
  171. auto val_case = proto_msg->value_case();
  172. if (val_case != proto::AttrDef::kList) {
  173. auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case);
  174. if (it != GeAttrValueImp::attr_val_one_type_map_.end()) {
  175. return it->second;
  176. }
  177. } else {
  178. auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type());
  179. if (it != GeAttrValueImp::attr_val_list_type_map_.end()) {
  180. return it->second;
  181. }
  182. }
  183. }
  184. return GeAttrValue::VT_NONE;
  185. }
  186. bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; }
  187. GeAttrValue GeAttrValue::Copy() const {
  188. GeAttrValue valueRet;
  189. auto proto_msg = value_.GetProtoMsg();
  190. auto proto_msg_ret = valueRet.value_.GetProtoMsg();
  191. if (proto_msg != nullptr && proto_msg_ret != nullptr) {
  192. *proto_msg_ret = *proto_msg;
  193. }
  194. return valueRet;
  195. }
  196. #define ATTR_VALUE_SET_GET_IMP(type) \
  197. graphStatus GeAttrValue::SetValue(const type &val) { \
  198. auto proto_msg = value_.GetProtoMsg(); \
  199. if (proto_msg) { \
  200. if (GeAttrValueImp::SetValue(*proto_msg, val)) { \
  201. return GRAPH_SUCCESS; \
  202. } \
  203. } \
  204. return GRAPH_FAILED; \
  205. } \
  206. \
  207. graphStatus GeAttrValue::GetValue(type &val) const { \
  208. auto proto_msg = value_.GetProtoMsg(); \
  209. if (proto_msg) { \
  210. if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \
  211. return GRAPH_SUCCESS; \
  212. } \
  213. } \
  214. return GRAPH_FAILED; \
  215. }
  216. ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
  217. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
  218. ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
  219. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
  220. ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524
  221. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
  222. ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
  223. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
  224. ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC)
  225. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR_DESC>)
  226. ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR)
  227. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR>)
  228. ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH)
  229. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::GRAPH>)
  230. ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
  231. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
  232. ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
  233. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
  234. /*lint -e665*/
  235. ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
  236. /*lint +e665*/
  237. ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665
  238. ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665
  239. #undef ATTR_VALUE_SET_GET_IMP
  240. graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); }
  241. graphStatus GeAttrValue::MutableListTensor(vector<GeTensorPtr> &list_tensor) { return GetValue(list_tensor); }
  242. class AttrUtilsHelper {
  243. public:
  244. inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) {
  245. if (attr_def.value_case() != proto_case) {
  246. GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case);
  247. return false;
  248. }
  249. return true;
  250. }
  251. inline static bool GetValueCheckListType(
  252. const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case,
  253. const std::function<bool(const proto::AttrDef &proto_attr_val)> item_check_fun) {
  254. if (attr_def.value_case() != proto::AttrDef::kList) {
  255. GELOGW("Check ListType Failed, value_case %u", attr_def.value_case());
  256. return false;
  257. }
  258. auto &list = attr_def.list();
  259. if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) {
  260. return item_check_fun(attr_def);
  261. }
  262. if (list.val_type() != proto_list_case) {
  263. GELOGW("Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case);
  264. return false;
  265. }
  266. return true;
  267. }
  268. inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) {
  269. if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) {
  270. GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case);
  271. return false;
  272. }
  273. return true;
  274. }
  275. inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def,
  276. proto::AttrDef_ListValue_ListValueType proto_list_case) {
  277. if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) {
  278. GELOGW("AttrUtils::Check Type Failed, value_case %u", attr_def.value_case());
  279. return false;
  280. }
  281. auto list = attr_def.mutable_list();
  282. if (list == nullptr) {
  283. GELOGE(GRAPH_FAILED, "list is nullptr");
  284. return false;
  285. }
  286. if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE &&
  287. list->val_type() != proto_list_case) {
  288. GELOGW("AttrUtils::Check ListType Type Failed, val_type %d, expected %d", static_cast<int>(list->val_type()),
  289. static_cast<int>(proto_list_case));
  290. return false;
  291. }
  292. list->set_val_type(proto_list_case);
  293. return true;
  294. }
  295. static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) {
  296. if (obj == nullptr) {
  297. GELOGE(FAILED, "%s obj is nullptr", name.c_str());
  298. return false;
  299. }
  300. auto attr_map = obj->GetAttrMap().GetProtoMsg();
  301. if (attr_map == nullptr) {
  302. GELOGE(FAILED, "%s attr map is nullptr", name.c_str());
  303. return false;
  304. }
  305. auto it = attr_map->find(name);
  306. if (it == attr_map->end()) {
  307. return false;
  308. }
  309. attr_def = &it->second;
  310. return true;
  311. }
  312. inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) {
  313. if (obj == nullptr) {
  314. GELOGE(FAILED, " %s obj is nullptr", name.c_str());
  315. return false;
  316. }
  317. auto attr_map = obj->MutableAttrMap().GetProtoMsg();
  318. if (attr_map == nullptr) {
  319. GELOGE(FAILED, "%s attr map is nullptr", name.c_str());
  320. return false;
  321. }
  322. // Get or add
  323. attr_def = &((*attr_map)[name]);
  324. return true;
  325. }
  326. };
  327. #define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \
  328. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \
  329. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \
  330. return false; \
  331. } \
  332. proto_attr_val.set_##protoItem(value); \
  333. return true; \
  334. }
  335. #define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \
  336. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \
  337. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \
  338. proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \
  339. return false; \
  340. } \
  341. auto list = proto_attr_val.mutable_list(); \
  342. list->clear_##protoItem(); \
  343. for (const auto &item : value) { \
  344. list->add_##protoItem(item); \
  345. } \
  346. return true; \
  347. }
  348. ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i)
  349. ATTR_VALUE_IMP_SET_ONE(float, kF, f)
  350. ATTR_VALUE_IMP_SET_ONE(const string &, kS, s)
  351. ATTR_VALUE_IMP_SET_ONE(bool, kB, b)
  352. ATTR_VALUE_IMP_SET_LIST(const vector<int64_t> &, VT_LIST_INT, i)
  353. ATTR_VALUE_IMP_SET_LIST(const vector<int32_t> &, VT_LIST_INT, i)
  354. ATTR_VALUE_IMP_SET_LIST(const vector<uint32_t> &, VT_LIST_INT, i)
  355. ATTR_VALUE_IMP_SET_LIST(const vector<float> &, VT_LIST_FLOAT, f)
  356. ATTR_VALUE_IMP_SET_LIST(const vector<string> &, VT_LIST_STRING, s)
  357. ATTR_VALUE_IMP_SET_LIST(const vector<bool> &, VT_LIST_BOOL, b)
  358. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) {
  359. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) {
  360. return false;
  361. }
  362. auto proto_msg = value.tensor_descriptor_.GetProtoMsg();
  363. if (proto_msg == nullptr) {
  364. return false;
  365. }
  366. *proto_attr_val.mutable_td() = *proto_msg;
  367. return true;
  368. }
  369. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorDesc> &value) {
  370. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  371. proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) {
  372. return false;
  373. }
  374. auto list = proto_attr_val.mutable_list();
  375. GE_CHECK_NOTNULL_EXEC(list, return false);
  376. list->clear_td();
  377. for (const auto &item : value) {
  378. auto proto_msg = item.tensor_descriptor_.GetProtoMsg();
  379. if (proto_msg == nullptr) {
  380. proto_attr_val.clear_list();
  381. return false;
  382. }
  383. *list->add_td() = *proto_msg;
  384. }
  385. return true;
  386. }
  387. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) {
  388. if (value) {
  389. return SetValue(proto_attr_val, *value);
  390. } else {
  391. return SetValue(proto_attr_val, GeTensor());
  392. }
  393. }
  394. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) {
  395. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) {
  396. return false;
  397. }
  398. auto proto_msg = val.tensor_def_.GetProtoMsg();
  399. if (proto_msg == nullptr) {
  400. GELOGE(FAILED, "Proto msg is nullptr");
  401. return false;
  402. }
  403. *proto_attr_val.mutable_t() = *proto_msg;
  404. return true;
  405. }
  406. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value) {
  407. vector<ConstGeTensorPtr> constList(value.size());
  408. std::copy(value.begin(), value.end(), constList.begin());
  409. return SetValue(proto_attr_val, constList);
  410. }
  411. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value) {
  412. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  413. proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) {
  414. return false;
  415. }
  416. auto list = proto_attr_val.mutable_list();
  417. GE_CHECK_NOTNULL_EXEC(list, return false);
  418. list->clear_t();
  419. for (const auto &item : value) {
  420. if (item == nullptr) {
  421. GELOGE(GRAPH_FAILED, "AttrUtils::SetListTensor item is nullptr");
  422. proto_attr_val.clear_list();
  423. return false;
  424. }
  425. auto proto_msg = item->tensor_def_.GetProtoMsg();
  426. if (proto_msg == nullptr) {
  427. GELOGE(FAILED, "Proto msg is nullptr");
  428. proto_attr_val.clear_list();
  429. return false;
  430. }
  431. *list->add_t() = *proto_msg;
  432. }
  433. return true;
  434. }
  435. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensor> &value) {
  436. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  437. proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) {
  438. return false;
  439. }
  440. auto list = proto_attr_val.mutable_list();
  441. GE_CHECK_NOTNULL_EXEC(list, return false);
  442. list->clear_t();
  443. for (const auto &item : value) {
  444. auto proto_msg = item.tensor_def_.GetProtoMsg();
  445. if (proto_msg == nullptr) {
  446. GELOGE(FAILED, "Proto msg is nullptr");
  447. proto_attr_val.clear_list();
  448. return false;
  449. }
  450. *list->add_t() = *proto_msg;
  451. }
  452. return true;
  453. }
  454. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) {
  455. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  456. return false;
  457. }
  458. size_t val_size = value.GetSize();
  459. proto_attr_val.set_bt(value.GetData(), val_size);
  460. return true;
  461. }
  462. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::BYTES> &value) {
  463. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  464. proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) {
  465. return false;
  466. }
  467. auto list = proto_attr_val.mutable_list();
  468. GE_CHECK_NOTNULL_EXEC(list, return false);
  469. list->clear_bt();
  470. for (const auto &item : value) {
  471. list->add_bt(item.GetData(), item.GetSize());
  472. }
  473. return true;
  474. }
  475. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) {
  476. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
  477. return false;
  478. }
  479. auto proto_msg = value.named_attrs_.GetProtoMsg();
  480. if (proto_msg == nullptr) {
  481. GELOGE(FAILED, "Proto msg is nullptr");
  482. return false;
  483. }
  484. *proto_attr_val.mutable_func() = *proto_msg;
  485. return true;
  486. }
  487. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) {
  488. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  489. proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) {
  490. return false;
  491. }
  492. auto list = proto_attr_val.mutable_list();
  493. GE_CHECK_NOTNULL_EXEC(list, return false);
  494. list->clear_na();
  495. for (const auto &item : value) {
  496. auto proto_msg = item.named_attrs_.GetProtoMsg();
  497. if (proto_msg == nullptr) {
  498. proto_attr_val.clear_list();
  499. return false;
  500. }
  501. *list->add_na() = *proto_msg;
  502. }
  503. return true;
  504. }
  505. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) {
  506. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) {
  507. return false;
  508. }
  509. ModelSerializeImp imp;
  510. if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) {
  511. GELOGE(GRAPH_FAILED, "AttrUtils::SetGraph SerializeGraph Failed");
  512. proto_attr_val.clear_g();
  513. return false;
  514. }
  515. return true;
  516. }
  517. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<ge::ComputeGraphPtr> &value) {
  518. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  519. proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) {
  520. return false;
  521. }
  522. auto list = proto_attr_val.mutable_list();
  523. GE_CHECK_NOTNULL_EXEC(list, return false);
  524. list->clear_g();
  525. ModelSerializeImp imp;
  526. for (const auto &item : value) {
  527. if (!imp.SerializeGraph(item, list->add_g())) {
  528. GELOGE(GRAPH_FAILED, "AttrUtils::SetListGraph SerializeGraph");
  529. proto_attr_val.clear_list();
  530. return false;
  531. }
  532. }
  533. return true;
  534. }
  535. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<vector<int64_t>> &value) {
  536. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) {
  537. return false;
  538. }
  539. proto_attr_val.clear_list_list_int();
  540. auto list_list_int = proto_attr_val.mutable_list_list_int();
  541. GE_CHECK_NOTNULL_EXEC(list_list_int, return false);
  542. for (auto &list_int : value) {
  543. auto list_item = list_list_int->add_list_list_i();
  544. GE_CHECK_NOTNULL_EXEC(list_item, return false);
  545. for (auto &int_item : list_int) {
  546. list_item->add_list_i(int_item);
  547. }
  548. }
  549. return true;
  550. }
  551. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<ge::DataType> &value) {
  552. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  553. proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) {
  554. return false;
  555. }
  556. auto list = proto_attr_val.mutable_list();
  557. GE_CHECK_NOTNULL_EXEC(list, return false);
  558. list->clear_dt();
  559. for (const auto &item : value) {
  560. list->add_dt(static_cast<int64_t>(item));
  561. }
  562. return true;
  563. }
  564. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) {
  565. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) {
  566. return false;
  567. }
  568. proto_attr_val.set_dt(static_cast<int64_t>(value));
  569. return true;
  570. }
  // ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) defines a scalar GeAttrValueImp::GetValue overload.
  // It returns false unless the AttrDef's value case is proto::AttrDef::proto_case. Otherwise it assigns
  // proto_attr_val.protoItem() to `value` and returns true. ValType is passed as a reference type
  // (e.g. `int64_t &`) so that the assignment reaches the caller's variable.
  // The macro body is backslash-continued, so keep comments outside of it.
  571. #define ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) \
  572. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ValType value) { \
  573. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \
  574. return false; \
  575. } \
  576. value = proto_attr_val.protoItem(); \
  577. return true; \
  578. }
  // ListValueItemCheck(protoItem) expands to a lambda that reports whether the AttrDef's list holds at least
  // one `protoItem` element, using the generated `protoItem##_size()` accessor. The lambda is passed to
  // AttrUtilsHelper::GetValueCheckListType, presumably to tell an empty list apart from a populated one.
  // Confirm against that helper.
  579. #define ListValueItemCheck(protoItem) \
  580. [](const proto::AttrDef &proto_attr_val) { return proto_attr_val.list().protoItem##_size() > 0; }
  // ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) defines a GetValue overload for vector<ValType>.
  // It clears `value` and returns false if GetValueCheckListType rejects the list type
  // AttrDef_ListValue_ListValueType_<proto_list_case>. Otherwise it appends every element of list().protoItem().
  581. #define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \
  582. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector<ValType> &value) { \
  583. value.clear(); \
  584. if (!AttrUtilsHelper::GetValueCheckListType( \
  585. proto_attr_val, proto::AttrDef_ListValue_ListValueType_##proto_list_case, ListValueItemCheck(protoItem))) { \
  586. return false; \
  587. } \
  588. auto &list = proto_attr_val.list(); \
  589. for (const auto &item : list.protoItem()) { \
  590. value.push_back(item); \
  591. } \
  592. return true; \
  593. }
  // Scalar getters for int64 / float / string / bool attributes.
  594. ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i)
  595. ATTR_VALUE_IMP_GET_ONE(float &, kF, f)
  596. ATTR_VALUE_IMP_GET_ONE(string &, kS, s)
  597. ATTR_VALUE_IMP_GET_ONE(bool &, kB, b)
  // List getters for the same four element types.
  598. ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i)
  599. ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f)
  600. ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s)
  601. ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b)
  602. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) {
  603. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) {
  604. return false;
  605. }
  606. auto proto_msg = value.tensor_descriptor_.GetProtoMsg();
  607. if (proto_msg == nullptr) {
  608. return false;
  609. }
  610. *proto_msg = proto_attr_val.td();
  611. return true;
  612. }
  613. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  614. vector<GeTensorDesc> &value) {
  615. if (!AttrUtilsHelper::GetValueCheckListType(
  616. proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) {
  617. return false;
  618. }
  619. auto &list = proto_attr_val.list();
  620. for (const auto &item : list.td()) {
  621. value.emplace_back(GeTensorDesc());
  622. auto proto_msg = value.back().tensor_descriptor_.GetProtoMsg();
  623. if (proto_msg == nullptr) {
  624. return false;
  625. }
  626. *proto_msg = item;
  627. }
  628. return true;
  629. }
  630. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  631. GeTensorPtr &value) {
  632. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) {
  633. return false;
  634. }
  635. value = std::shared_ptr<GeTensor>(new (std::nothrow)
  636. GeTensor(proto_owner, const_cast<proto::AttrDef &>(proto_attr_val).mutable_t()));
  637. GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr");
  638. return true;
  639. }
  640. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  641. vector<GeTensorPtr> &value) {
  642. value.clear();
  643. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR,
  644. ListValueItemCheck(t))) {
  645. return false;
  646. }
  647. auto list = const_cast<proto::AttrDef &>(proto_attr_val).mutable_list();
  648. GE_CHECK_NOTNULL_EXEC(list, return false);
  649. for (auto &item : *(list->mutable_t())) {
  650. std::shared_ptr<GeTensor> temp_value = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(proto_owner, &item));
  651. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  652. value.push_back(temp_value);
  653. }
  654. return true;
  655. }
  656. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) {
  657. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  658. return false;
  659. }
  660. auto &proto_val = proto_attr_val.bt();
  661. GE_LOGI_IF(proto_val.size() == 0, "size res is 0.");
  662. value = Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(proto_val.data()), proto_val.size());
  663. return true;
  664. }
  665. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  666. vector<GeAttrValue::BYTES> &value) {
  667. value.clear();
  668. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES,
  669. ListValueItemCheck(bt))) {
  670. return false;
  671. }
  672. auto &list = proto_attr_val.list();
  673. for (const auto &item : list.bt()) {
  674. value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size()));
  675. }
  676. return true;
  677. }
  678. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  679. GeAttrValue::NAMED_ATTRS &value) {
  680. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
  681. return false;
  682. }
  683. auto proto_msg = value.named_attrs_.GetProtoMsg();
  684. if (proto_msg == nullptr) {
  685. return false;
  686. }
  687. *proto_msg = proto_attr_val.func();
  688. return true;
  689. }
  690. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  691. vector<GeAttrValue::NAMED_ATTRS> &value) {
  692. value.clear();
  693. if (!AttrUtilsHelper::GetValueCheckListType(
  694. proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) {
  695. return false;
  696. }
  697. auto &list = proto_attr_val.list();
  698. for (const auto &item : list.na()) {
  699. value.emplace_back(GeAttrValue::NAMED_ATTRS());
  700. if (value.empty()) {
  701. return false;
  702. }
  703. auto proto_msg = value.back().named_attrs_.GetProtoMsg();
  704. if (proto_msg == nullptr) {
  705. return false;
  706. }
  707. *proto_msg = item;
  708. }
  709. return true;
  710. }
  711. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) {
  712. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) {
  713. return false;
  714. }
  715. ComputeGraphPtr graph = nullptr;
  716. std::shared_ptr<proto::GraphDef> graph_def;
  717. graph_def = ComGraphMakeShared<proto::GraphDef>(proto_attr_val.g());
  718. if (graph_def == nullptr) {
  719. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  720. graph_def = nullptr;
  721. return false; // lint !e665
  722. } else {
  723. ModelSerializeImp imp;
  724. imp.SetProtobufOwner(graph_def);
  725. if (!imp.UnserializeGraph(graph, *graph_def)) {
  726. GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
  727. return false;
  728. } // lint !e514
  729. value = graph;
  730. }
  731. return true;
  732. }
  733. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  734. vector<ComputeGraphPtr> &value) {
  735. value.clear();
  736. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH,
  737. ListValueItemCheck(g))) {
  738. return false;
  739. }
  740. auto &list = proto_attr_val.list();
  741. for (const auto &item : list.g()) {
  742. std::shared_ptr<proto::GraphDef> graph_def;
  743. graph_def = ComGraphMakeShared<proto::GraphDef>(item);
  744. if (graph_def == nullptr) {
  745. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  746. graph_def = nullptr;
  747. return false; // lint !e665
  748. } else {
  749. ComputeGraphPtr graph = nullptr;
  750. ModelSerializeImp imp;
  751. imp.SetProtobufOwner(graph_def);
  752. if (!imp.UnserializeGraph(graph, *graph_def)) {
  753. GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
  754. return false;
  755. } // lint !e514
  756. value.push_back(graph);
  757. }
  758. }
  759. return true;
  760. }
  761. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  762. vector<vector<int64_t>> &value) {
  763. value.clear();
  764. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) {
  765. return false;
  766. }
  767. auto &list_listint = proto_attr_val.list_list_int().list_list_i();
  768. for (auto &list_int : list_listint) {
  769. vector<int64_t> list_item(list_int.list_i().size());
  770. if (!list_int.list_i().empty()) {
  771. (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin());
  772. }
  773. value.push_back(list_item);
  774. }
  775. return true;
  776. }
  777. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  778. vector<ge::DataType> &value) {
  779. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE,
  780. ListValueItemCheck(dt))) {
  781. return false;
  782. }
  783. auto &list = proto_attr_val.list();
  784. for (const auto &item : list.dt()) {
  785. value.emplace_back(static_cast<ge::DataType>(item));
  786. }
  787. return true;
  788. }
  789. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) {
  790. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) {
  791. return false;
  792. }
  793. value = static_cast<ge::DataType>(proto_attr_val.dt());
  794. return true;
  795. }
  796. GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  797. Buffer &&buffer) {
  798. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  799. return false;
  800. }
  801. auto proto_msg = buffer.data_.GetProtoMsg();
  802. if (proto_msg == nullptr) {
  803. return false;
  804. }
  805. proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt()));
  806. return true;
  807. }
  808. bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  809. Buffer &buffer) {
  810. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  811. return false;
  812. }
  813. buffer = Buffer(proto_owner, &const_cast<proto::AttrDef &>(proto_attr_val));
  814. return true;
  815. }
  816. bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  817. vector<Buffer> &list_buffer) {
  818. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  819. proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) {
  820. return false;
  821. }
  822. auto list = proto_attr_val.mutable_list();
  823. GE_CHECK_NOTNULL_EXEC(list, return false);
  824. list->clear_bt();
  825. for (auto &item : list_buffer) {
  826. auto proto_msg = item.data_.GetProtoMsg();
  827. if (proto_msg == nullptr) {
  828. return false;
  829. }
  830. list->add_bt(std::move(*proto_msg->mutable_bt()));
  831. }
  832. return true;
  833. }
  834. bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  835. vector<Buffer> &list_buffer) {
  836. list_buffer.clear();
  837. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES,
  838. ListValueItemCheck(bt))) {
  839. return false;
  840. }
  841. auto list = const_cast<proto::AttrDef &>(proto_attr_val).mutable_list();
  842. GE_CHECK_NOTNULL_EXEC(list, return false);
  843. for (auto &item : *(list->mutable_bt())) {
  844. list_buffer.emplace_back(Buffer(proto_owner, &item));
  845. }
  846. return true;
  847. }
  848. bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) {
  849. if (!obj) {
  850. return false;
  851. }
  852. return obj->HasAttr(name);
  853. }
  // ATTR_UTILS_SET_IMP(FuncName, Type) defines AttrUtils::Set<FuncName>(obj, name, const Type &value).
  // It obtains the mutable AttrDef for `name` through AttrUtilsHelper::MutableAttrMapItem (presumably
  // creating it when absent; confirm there) and returns false if that fails or yields null.
  // Otherwise it forwards to the matching GeAttrValueImp::SetValue overload, logging a warning
  // ("Set<FuncName> failed key ...") when the conversion is rejected.
  // The macro bodies below are backslash-continued, so keep comments outside of them.
  854. #define ATTR_UTILS_SET_IMP(FuncName, Type) \
  855. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \
  856. AttrHolderAdapter &&obj, const string &name, const Type &value) { \
  857. proto::AttrDef *proto_attr_val = nullptr; \
  858. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \
  859. return false; \
  860. } \
  861. if (!GeAttrValueImp::SetValue(*proto_attr_val, value)) { \
  862. GELOGW("Set" #FuncName " failed key %s", name.c_str()); \
  863. return false; \
  864. } \
  865. return true; \
  866. }
  // ATTR_UTILS_GET_IMP(FuncName, Type) defines AttrUtils::Get<FuncName>(obj, name, Type &value).
  // It fetches the const AttrDef for `name` via AttrUtilsHelper::GetAttrMapItem and forwards it, together
  // with the attribute map's proto owner, to GeAttrValueImp::GetValue. It logs a warning
  // ("Get<FuncName> failed key ...") when the value cannot be read.
  867. #define ATTR_UTILS_GET_IMP(FuncName, Type) \
  868. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##FuncName(ConstAttrHolderAdapter &&obj, \
  869. const string &name, Type &value) { \
  870. const proto::AttrDef *proto_attr_val = nullptr; \
  871. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \
  872. return false; \
  873. } \
  874. if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \
  875. GELOGW("Get" #FuncName " failed key %s", name.c_str()); \
  876. return false; \
  877. } \
  878. return true; \
  879. }
  // ATTR_UTILS_SET_GET_IMP(FuncName, Type) emits both the setter and the getter for one attribute type.
  880. #define ATTR_UTILS_SET_GET_IMP(FuncName, Type) \
  881. ATTR_UTILS_SET_IMP(FuncName, Type) \
  882. ATTR_UTILS_GET_IMP(FuncName, Type)
  // Setter/getter pairs generated from the ATTR_UTILS_* macros. Some types only get a setter here:
  // Tensor, ListTensor, and the int32_t/uint32_t ListInt variants. Their getters are hand-written
  // later in this file (GetTensor/MutableTensor, GetListTensor/MutableListTensor, and the range-checked
  // GetListInt overloads).
  883. ATTR_UTILS_SET_GET_IMP(Int, int64_t)
  884. ATTR_UTILS_SET_GET_IMP(Float, float)
  885. ATTR_UTILS_SET_GET_IMP(Bool, bool)
  886. ATTR_UTILS_SET_GET_IMP(Str, string)
  887. ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc)
  888. ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr)
  889. ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr)
  890. ATTR_UTILS_SET_IMP(Tensor, GeTensor)
  891. ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
  892. ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
  893. ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
  894. /*lint -e665*/
  895. ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>)
  896. /*lint +e665*/
  // The ListInt setter also accepts int32_t and uint32_t vectors; only the int64_t getter is generated here.
  897. ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>)
  898. ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>)
  899. ATTR_UTILS_SET_IMP(ListInt, vector<uint32_t>)
  900. ATTR_UTILS_SET_GET_IMP(ListFloat, vector<float>)
  901. ATTR_UTILS_SET_GET_IMP(ListBool, vector<bool>)
  902. ATTR_UTILS_SET_GET_IMP(ListStr, vector<string>)
  903. ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>)
  904. ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>)
  905. ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>)
  906. ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
  907. ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
  908. ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
  909. ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
  910. ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665
  911. ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665
  912. bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name,
  913. std::initializer_list<ConstGeTensorPtr> &&value) {
  914. return SetListTensor(std::move(obj), name, vector<ConstGeTensorPtr>(value));
  915. }
  916. bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) {
  917. const proto::AttrDef *proto_attr_val = nullptr;
  918. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  919. return false;
  920. }
  921. GeTensorPtr tensor;
  922. if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) {
  923. return false;
  924. }
  925. value = tensor;
  926. return true;
  927. }
  928. bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value) {
  929. value.clear();
  930. const proto::AttrDef *proto_attr_val = nullptr;
  931. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  932. return false;
  933. }
  934. vector<GeTensorPtr> tensor;
  935. if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) {
  936. return false;
  937. }
  938. value.insert(value.begin(), tensor.begin(), tensor.end());
  939. return true;
  940. }
  941. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj,
  942. const string &name, GeTensorPtr &value) {
  943. const proto::AttrDef *proto_attr_val = nullptr;
  944. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  945. return false;
  946. }
  947. return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value);
  948. }
  949. bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value) {
  950. value.clear();
  951. const proto::AttrDef *proto_attr_val = nullptr;
  952. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  953. return false;
  954. }
  955. return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value);
  956. }
  957. bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value) {
  958. proto::AttrDef *proto_attr_val = nullptr;
  959. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  960. return false;
  961. }
  962. return GeAttrValueImp::SetValue(*proto_attr_val, value);
  963. }
  964. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name,
  965. int32_t &value) {
  966. int64_t int64_val = 0;
  967. if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) {
  968. return false;
  969. }
  970. if (int64_val > INT32_MAX) {
  971. GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to int32_t", int64_val);
  972. return false;
  973. }
  974. value = static_cast<int32_t>(int64_val);
  975. return true;
  976. }
  977. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name,
  978. uint32_t &value) {
  979. int64_t int64_val = 0;
  980. if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) {
  981. return false;
  982. }
  983. if (int64_val > UINT32_MAX) {
  984. GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to uint32_t", int64_val);
  985. return false;
  986. }
  987. value = static_cast<uint32_t>(int64_val);
  988. return true;
  989. }
  990. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj,
  991. const string &name, vector<int32_t> &value) {
  992. value.clear();
  993. vector<int64_t> int64_list;
  994. if (!GetListInt(std::move(obj), name, int64_list)) {
  995. return false;
  996. }
  997. for (size_t i = 0; i < int64_list.size(); ++i) {
  998. if (int64_list[i] > INT32_MAX) {
  999. GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]);
  1000. return false;
  1001. }
  1002. }
  1003. value.insert(value.begin(), int64_list.begin(), int64_list.end());
  1004. return true;
  1005. }
  1006. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj,
  1007. const string &name, vector<uint32_t> &value) {
  1008. value.clear();
  1009. vector<int64_t> int64_list;
  1010. if (!GetListInt(std::move(obj), name, int64_list)) {
  1011. return false;
  1012. }
  1013. for (size_t i = 0; i < int64_list.size(); ++i) {
  1014. if (int64_list[i] > UINT32_MAX) {
  1015. GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]);
  1016. return false;
  1017. }
  1018. }
  1019. value.insert(value.begin(), int64_list.begin(), int64_list.end());
  1020. return true;
  1021. }
  1022. bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value) {
  1023. if (obj) {
  1024. vector<Buffer> bytes_vals;
  1025. for (auto &item : value) {
  1026. ModelSerialize serialize;
  1027. auto buffer = serialize.SerializeOpDesc(item);
  1028. if (buffer.GetSize() == 0) {
  1029. return false;
  1030. }
  1031. bytes_vals.push_back(buffer);
  1032. }
  1033. return SetZeroCopyListBytes(std::move(obj), name, bytes_vals);
  1034. }
  1035. return false;
  1036. }
  1037. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj,
  1038. const string &name,
  1039. const vector<OpDescPtr> &value) {
  1040. if (obj) {
  1041. vector<Buffer> bytes_vals;
  1042. for (auto &item : value) {
  1043. ModelSerialize serialize;
  1044. auto buffer = serialize.SerializeOpDesc(item);
  1045. if (buffer.GetSize() == 0) {
  1046. return false;
  1047. }
  1048. bytes_vals.push_back(buffer);
  1049. }
  1050. return SetZeroCopyListBytes(std::move(obj), name, bytes_vals);
  1051. }
  1052. return false;
  1053. }
  1054. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj,
  1055. const string &name,
  1056. vector<OpDescPtr> &value) {
  1057. value.clear();
  1058. vector<Buffer> bytes_vals;
  1059. if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) {
  1060. return false;
  1061. }
  1062. for (const auto &item : bytes_vals) {
  1063. ModelSerialize serialize;
  1064. auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732
  1065. value.push_back(op_desc);
  1066. }
  1067. return true;
  1068. }
  1069. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj,
  1070. const string &name, Buffer &&buffer) {
  1071. // Value will be moved
  1072. proto::AttrDef *proto_attr_val = nullptr;
  1073. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1074. return false;
  1075. }
  1076. return GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer));
  1077. }
  1078. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj,
  1079. const string &name, Buffer &buffer) {
  1080. const proto::AttrDef *proto_attr_val = nullptr;
  1081. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1082. return false;
  1083. }
  1084. return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer);
  1085. }
  1086. bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &list_buffer) {
  1087. // Value will be moved
  1088. proto::AttrDef *proto_attr_val = nullptr;
  1089. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1090. return false;
  1091. }
  1092. return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer);
  1093. }
  1094. bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &list_buffer) {
  1095. list_buffer.clear();
  1096. const proto::AttrDef *proto_attr_val = nullptr;
  1097. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1098. return false;
  1099. }
  1100. return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer);
  1101. }
  1102. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) {
  1103. if (org_op_desc == nullptr) {
  1104. GELOGE(GRAPH_FAILED, "org_op_desc is null");
  1105. return nullptr;
  1106. }
  1107. std::shared_ptr<proto::OpDef> op_def;
  1108. op_def = ComGraphMakeShared<proto::OpDef>();
  1109. if (op_def == nullptr) {
  1110. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  1111. return nullptr; // lint !e665
  1112. }
  1113. ModelSerializeImp imp;
  1114. (void)imp.SerializeOpDesc(org_op_desc, op_def.get());
  1115. imp.SetProtobufOwner(op_def);
  1116. OpDescPtr op_desc = nullptr;
  1117. GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed");
  1118. op_desc->extAttrs_ = org_op_desc->extAttrs_;
  1119. // This function may be called by some passes of fusion engine, in this condition, do not need these attribute
  1120. if (!op_desc->input_name_idx_.empty()) {
  1121. op_desc->input_name_idx_.clear();
  1122. }
  1123. if (!op_desc->output_name_idx_.empty()) {
  1124. op_desc->output_name_idx_.clear();
  1125. }
  1126. if (!op_desc->optional_input_names_.empty()) {
  1127. op_desc->optional_input_names_.clear();
  1128. }
  1129. return op_desc;
  1130. }
  1131. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) {
  1132. if (org_op_desc == nullptr) {
  1133. GELOGE(GRAPH_FAILED, "org_op_desc is null");
  1134. return nullptr;
  1135. }
  1136. std::shared_ptr<proto::OpDef> op_def = ComGraphMakeShared<proto::OpDef>();
  1137. if (op_def == nullptr) {
  1138. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  1139. return nullptr;
  1140. }
  1141. ModelSerializeImp imp;
  1142. (void)imp.SerializeOpDesc(org_op_desc, op_def.get());
  1143. imp.SetProtobufOwner(op_def);
  1144. OpDescPtr op_desc = nullptr;
  1145. GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed");
  1146. op_desc->extAttrs_ = org_op_desc->extAttrs_;
  1147. op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end());
  1148. op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(),
  1149. org_op_desc->optional_input_names_.end());
  1150. op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end());
  1151. op_desc->infer_func_ = org_op_desc->infer_func_;
  1152. op_desc->infer_format_func_ = org_op_desc->infer_format_func_;
  1153. op_desc->verifier_func_ = org_op_desc->verifier_func_;
  1154. return op_desc;
  1155. }
  1156. std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) {
  1157. auto holder = obj.get();
  1158. if (holder == nullptr) {
  1159. return "";
  1160. }
  1161. auto attrs_map = holder->GetAttrMap();
  1162. if (attrs_map.GetProtoMsg() == nullptr) {
  1163. return "";
  1164. }
  1165. std::map<std::string, std::string> ordered_attrs;
  1166. for (auto &attr : *(attrs_map.GetProtoMsg())) {
  1167. ordered_attrs[attr.first] = attr.second.SerializeAsString();
  1168. }
  1169. std::stringstream ss;
  1170. for (auto &attr : ordered_attrs) {
  1171. ss << attr.first << ":" << attr.second << ";";
  1172. }
  1173. return ss.str();
  1174. }
  1175. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：