You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ge_attr_value.cc 57 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/ge_attr_value.h"
  17. #include "graph/ge_tensor.h"
  18. #include "external/graph/graph.h"
  19. #include "utils/attr_utils.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/model_serialize.h"
  22. #include "proto/ge_ir.pb.h"
  23. #include "detail/model_serialize_imp.h"
  24. #include "debug/ge_attr_define.h"
  25. #include "debug/ge_log.h"
  26. #include "debug/ge_util.h"
  27. using std::map;
  28. using std::string;
  29. using std::vector;
  30. namespace ge {
  31. NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); }
  32. NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg)
  33. : named_attrs_(owner, proto_msg) {} // lint !e1744
  34. void NamedAttrs::SetName(const std::string &name) {
  35. auto proto_msg = named_attrs_.GetProtoMsg();
  36. if (proto_msg != nullptr) {
  37. proto_msg->set_name(name);
  38. }
  39. }
  40. string NamedAttrs::GetName() const {
  41. auto proto_msg = named_attrs_.GetProtoMsg();
  42. if (proto_msg != nullptr) {
  43. return proto_msg->name();
  44. }
  45. return string();
  46. }
  47. GeAttrValue NamedAttrs::GetItem(const string &key) const {
  48. GeAttrValue value;
  49. (void)GetAttr(key, value);
  50. return value;
  51. }
  52. ProtoAttrMapHelper NamedAttrs::MutableAttrMap() {
  53. auto proto_msg = named_attrs_.GetProtoMsg();
  54. if (proto_msg != nullptr) {
  55. return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr());
  56. }
  57. return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr);
  58. }
  59. ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const {
  60. auto proto_msg = named_attrs_.GetProtoMsg();
  61. if (proto_msg != nullptr) {
  62. return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr());
  63. }
  64. return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr);
  65. }
  66. class GeAttrValueImp {
  67. public:
  68. static map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> attr_val_one_type_map_;
  69. static map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> attr_val_list_type_map_;
  70. static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val);
  71. static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val);
  72. static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val);
  73. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val);
  74. static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val);
  75. static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val);
  76. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val);
  77. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val);
  78. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val);
  79. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val);
  80. static bool SetValue(proto::AttrDef &attr_def, const vector<int64_t> &val);
  81. static bool SetValue(proto::AttrDef &attr_def, const vector<int32_t> &val);
  82. static bool SetValue(proto::AttrDef &attr_def, const vector<uint32_t> &val);
  83. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val);
  84. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val);
  85. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val);
  86. static bool SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value);
  87. static bool SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value);
  88. static bool SetValue(proto::AttrDef &attr_def, const vector<GeTensor> &val);
  89. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val);
  90. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val);
  91. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val);
  92. static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val);
  93. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val);
  94. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val);
  95. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val);
  96. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val);
  97. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val);
  98. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val);
  99. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  100. GeAttrValue::TENSOR_DESC &val);
  101. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val);
  102. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  103. GeAttrValue::NAMED_ATTRS &val);
  104. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val);
  105. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  106. GeAttrValue::LIST_INT &val);
  107. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  108. GeAttrValue::LIST_FLOAT &val);
  109. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  110. GeAttrValue::LIST_BOOL &val);
  111. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  112. GeAttrValue::LIST_STR &val);
  113. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  114. GeAttrValue::LIST_TENSOR &val);
  115. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector<GeTensor> &val);
  116. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  117. GeAttrValue::LIST_TENSOR_DESC &val);
  118. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  119. GeAttrValue::LIST_BYTES &val);
  120. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  121. GeAttrValue::LIST_NAMED_ATTRS &val);
  122. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  123. GeAttrValue::LIST_GRAPH &val);
  124. // Value will be moved
  125. static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer);
  126. static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer);
  127. // Value will be moved
  128. static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  129. vector<Buffer> &list_buffer);
  130. static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  131. vector<Buffer> &list_buffer);
  132. static bool SetValue(proto::AttrDef &attr_def, const vector<vector<int64_t>> &value);
  133. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  134. vector<vector<int64_t>> &value);
  135. static bool SetValue(proto::AttrDef &attr_def, const vector<ge::DataType> &value);
  136. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
  137. vector<ge::DataType> &value);
  138. static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value);
  139. static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value);
  140. };
  141. map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> GeAttrValueImp::attr_val_one_type_map_ = {
  142. {proto::AttrDef::kI, GeAttrValue::VT_INT},
  143. {proto::AttrDef::kF, GeAttrValue::VT_FLOAT},
  144. {proto::AttrDef::kB, GeAttrValue::VT_BOOL},
  145. {proto::AttrDef::kS, GeAttrValue::VT_STRING},
  146. {proto::AttrDef::kT, GeAttrValue::VT_TENSOR},
  147. {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC},
  148. {proto::AttrDef::kG, GeAttrValue::VT_GRAPH},
  149. {proto::AttrDef::kBt, GeAttrValue::VT_BYTES},
  150. {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS},
  151. {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT},
  152. {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE},
  153. };
  154. map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> GeAttrValueImp::attr_val_list_type_map_ = {
  155. {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT},
  156. {proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT},
  157. {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL},
  158. {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING},
  159. {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR},
  160. {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC},
  161. {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH},
  162. {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES},
  163. {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS},
  164. {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE},
  165. };
  166. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); }
  167. GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {}
  168. GeAttrValue::ValueType GeAttrValue::GetValueType() const {
  169. auto proto_msg = value_.GetProtoMsg();
  170. if (proto_msg != nullptr) {
  171. auto val_case = proto_msg->value_case();
  172. if (val_case != proto::AttrDef::kList) {
  173. auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case);
  174. if (it != GeAttrValueImp::attr_val_one_type_map_.end()) {
  175. return it->second;
  176. }
  177. } else {
  178. auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type());
  179. if (it != GeAttrValueImp::attr_val_list_type_map_.end()) {
  180. return it->second;
  181. }
  182. }
  183. }
  184. return GeAttrValue::VT_NONE;
  185. }
  186. bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; }
  187. GeAttrValue GeAttrValue::Copy() const {
  188. GeAttrValue valueRet;
  189. auto proto_msg = value_.GetProtoMsg();
  190. auto proto_msg_ret = valueRet.value_.GetProtoMsg();
  191. if (proto_msg != nullptr && proto_msg_ret != nullptr) {
  192. *proto_msg_ret = *proto_msg;
  193. }
  194. return valueRet;
  195. }
  196. #define ATTR_VALUE_SET_GET_IMP(type) \
  197. graphStatus GeAttrValue::SetValue(const type &val) { \
  198. auto proto_msg = value_.GetProtoMsg(); \
  199. if (proto_msg) { \
  200. if (GeAttrValueImp::SetValue(*proto_msg, val)) { \
  201. return GRAPH_SUCCESS; \
  202. } \
  203. } \
  204. return GRAPH_FAILED; \
  205. } \
  206. \
  207. graphStatus GeAttrValue::GetValue(type &val) const { \
  208. auto proto_msg = value_.GetProtoMsg(); \
  209. if (proto_msg) { \
  210. if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \
  211. return GRAPH_SUCCESS; \
  212. } \
  213. } \
  214. return GRAPH_FAILED; \
  215. }
  216. ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
  217. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
  218. ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
  219. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
  220. ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524
  221. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
  222. ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
  223. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
  224. ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC)
  225. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR_DESC>)
  226. ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR)
  227. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR>)
  228. ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH)
  229. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::GRAPH>)
  230. ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
  231. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
  232. ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
  233. ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
  234. /*lint -e665*/
  235. ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
  236. /*lint +e665*/
  237. ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665
  238. ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665
  239. #undef ATTR_VALUE_SET_GET_IMP
  240. graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); }
  241. graphStatus GeAttrValue::MutableListTensor(vector<GeTensorPtr> &list_tensor) { return GetValue(list_tensor); }
  242. class AttrUtilsHelper {
  243. public:
  244. inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) {
  245. if (attr_def.value_case() != proto_case) {
  246. GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case);
  247. return false;
  248. }
  249. return true;
  250. }
  251. inline static bool GetValueCheckListType(
  252. const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case,
  253. const std::function<bool(const proto::AttrDef &proto_attr_val)> item_check_fun) {
  254. if (attr_def.value_case() != proto::AttrDef::kList) {
  255. GELOGW("Check ListType Failed, value_case %u", attr_def.value_case());
  256. return false;
  257. }
  258. auto &list = attr_def.list();
  259. if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) {
  260. return item_check_fun(attr_def);
  261. }
  262. if (list.val_type() != proto_list_case) {
  263. GELOGW("Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case);
  264. return false;
  265. }
  266. return true;
  267. }
  268. inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) {
  269. if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) {
  270. GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case);
  271. return false;
  272. }
  273. return true;
  274. }
  275. inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def,
  276. proto::AttrDef_ListValue_ListValueType proto_list_case) {
  277. if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) {
  278. GELOGW("AttrUtils::Check Type Failed, value_case %u", attr_def.value_case());
  279. return false;
  280. }
  281. auto list = attr_def.mutable_list();
  282. if (list == nullptr) {
  283. GELOGE(GRAPH_FAILED, "list is nullptr");
  284. return false;
  285. }
  286. if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE &&
  287. list->val_type() != proto_list_case) {
  288. GELOGW("AttrUtils::Check ListType Type Failed, val_type %d, expected %d", static_cast<int>(list->val_type()),
  289. static_cast<int>(proto_list_case));
  290. return false;
  291. }
  292. list->set_val_type(proto_list_case);
  293. return true;
  294. }
  295. static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) {
  296. if (obj == nullptr) {
  297. GELOGE(FAILED, "%s obj is nullptr", name.c_str());
  298. return false;
  299. }
  300. auto attr_map = obj->GetAttrMap().GetProtoMsg();
  301. if (attr_map == nullptr) {
  302. GELOGE(FAILED, "%s attr map is nullptr", name.c_str());
  303. return false;
  304. }
  305. auto it = attr_map->find(name);
  306. if (it == attr_map->end()) {
  307. return false;
  308. }
  309. attr_def = &it->second;
  310. return true;
  311. }
  312. inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) {
  313. if (obj == nullptr) {
  314. GELOGE(FAILED, " %s obj is nullptr", name.c_str());
  315. return false;
  316. }
  317. auto attr_map = obj->MutableAttrMap().GetProtoMsg();
  318. if (attr_map == nullptr) {
  319. GELOGE(FAILED, "%s attr map is nullptr", name.c_str());
  320. return false;
  321. }
  322. // Get or add
  323. attr_def = &((*attr_map)[name]);
  324. return true;
  325. }
  326. };
  327. #define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \
  328. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \
  329. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \
  330. return false; \
  331. } \
  332. proto_attr_val.set_##protoItem(value); \
  333. return true; \
  334. }
  335. #define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \
  336. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \
  337. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \
  338. proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \
  339. return false; \
  340. } \
  341. auto list = proto_attr_val.mutable_list(); \
  342. list->clear_##protoItem(); \
  343. for (const auto &item : value) { \
  344. list->add_##protoItem(item); \
  345. } \
  346. return true; \
  347. }
  348. ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i)
  349. ATTR_VALUE_IMP_SET_ONE(float, kF, f)
  350. ATTR_VALUE_IMP_SET_ONE(const string &, kS, s)
  351. ATTR_VALUE_IMP_SET_ONE(bool, kB, b)
  352. ATTR_VALUE_IMP_SET_LIST(const vector<int64_t> &, VT_LIST_INT, i)
  353. ATTR_VALUE_IMP_SET_LIST(const vector<int32_t> &, VT_LIST_INT, i)
  354. ATTR_VALUE_IMP_SET_LIST(const vector<uint32_t> &, VT_LIST_INT, i)
  355. ATTR_VALUE_IMP_SET_LIST(const vector<float> &, VT_LIST_FLOAT, f)
  356. ATTR_VALUE_IMP_SET_LIST(const vector<string> &, VT_LIST_STRING, s)
  357. ATTR_VALUE_IMP_SET_LIST(const vector<bool> &, VT_LIST_BOOL, b)
  358. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) {
  359. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) {
  360. return false;
  361. }
  362. auto proto_msg = value.tensor_descriptor_.GetProtoMsg();
  363. if (proto_msg == nullptr) {
  364. return false;
  365. }
  366. *proto_attr_val.mutable_td() = *proto_msg;
  367. return true;
  368. }
  369. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorDesc> &value) {
  370. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  371. proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) {
  372. return false;
  373. }
  374. auto list = proto_attr_val.mutable_list();
  375. GE_CHECK_NOTNULL_EXEC(list, return false);
  376. list->clear_td();
  377. for (const auto &item : value) {
  378. auto proto_msg = item.tensor_descriptor_.GetProtoMsg();
  379. if (proto_msg == nullptr) {
  380. proto_attr_val.clear_list();
  381. return false;
  382. }
  383. *list->add_td() = *proto_msg;
  384. }
  385. return true;
  386. }
  387. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) {
  388. if (value) {
  389. return SetValue(proto_attr_val, *value);
  390. } else {
  391. return SetValue(proto_attr_val, GeTensor());
  392. }
  393. }
  394. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) {
  395. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) {
  396. return false;
  397. }
  398. auto proto_msg = val.tensor_def_.GetProtoMsg();
  399. if (proto_msg == nullptr) {
  400. GELOGE(FAILED, "Proto msg is nullptr");
  401. return false;
  402. }
  403. *proto_attr_val.mutable_t() = *proto_msg;
  404. return true;
  405. }
  406. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensorPtr> &value) {
  407. vector<ConstGeTensorPtr> constList(value.size());
  408. std::copy(value.begin(), value.end(), constList.begin());
  409. return SetValue(proto_attr_val, constList);
  410. }
  411. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<ConstGeTensorPtr> &value) {
  412. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  413. proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) {
  414. return false;
  415. }
  416. auto list = proto_attr_val.mutable_list();
  417. GE_CHECK_NOTNULL_EXEC(list, return false);
  418. list->clear_t();
  419. for (const auto &item : value) {
  420. if (item == nullptr) {
  421. GELOGE(GRAPH_FAILED, "AttrUtils::SetListTensor item is nullptr");
  422. proto_attr_val.clear_list();
  423. return false;
  424. }
  425. auto proto_msg = item->tensor_def_.GetProtoMsg();
  426. if (proto_msg == nullptr) {
  427. GELOGE(FAILED, "Proto msg is nullptr");
  428. proto_attr_val.clear_list();
  429. return false;
  430. }
  431. *list->add_t() = *proto_msg;
  432. }
  433. return true;
  434. }
  435. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeTensor> &value) {
  436. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  437. proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) {
  438. return false;
  439. }
  440. auto list = proto_attr_val.mutable_list();
  441. GE_CHECK_NOTNULL_EXEC(list, return false);
  442. list->clear_t();
  443. for (const auto &item : value) {
  444. auto proto_msg = item.tensor_def_.GetProtoMsg();
  445. if (proto_msg == nullptr) {
  446. GELOGE(FAILED, "Proto msg is nullptr");
  447. proto_attr_val.clear_list();
  448. return false;
  449. }
  450. *list->add_t() = *proto_msg;
  451. }
  452. return true;
  453. }
  454. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) {
  455. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  456. return false;
  457. }
  458. size_t val_size = value.GetSize();
  459. proto_attr_val.set_bt(value.GetData(), val_size);
  460. return true;
  461. }
  462. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::BYTES> &value) {
  463. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  464. proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) {
  465. return false;
  466. }
  467. auto list = proto_attr_val.mutable_list();
  468. GE_CHECK_NOTNULL_EXEC(list, return false);
  469. list->clear_bt();
  470. for (const auto &item : value) {
  471. list->add_bt(item.GetData(), item.GetSize());
  472. }
  473. return true;
  474. }
  475. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) {
  476. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
  477. return false;
  478. }
  479. auto proto_msg = value.named_attrs_.GetProtoMsg();
  480. if (proto_msg == nullptr) {
  481. GELOGE(FAILED, "Proto msg is nullptr");
  482. return false;
  483. }
  484. *proto_attr_val.mutable_func() = *proto_msg;
  485. return true;
  486. }
  487. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) {
  488. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  489. proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) {
  490. return false;
  491. }
  492. auto list = proto_attr_val.mutable_list();
  493. GE_CHECK_NOTNULL_EXEC(list, return false);
  494. list->clear_na();
  495. for (const auto &item : value) {
  496. auto proto_msg = item.named_attrs_.GetProtoMsg();
  497. if (proto_msg == nullptr) {
  498. proto_attr_val.clear_list();
  499. return false;
  500. }
  501. *list->add_na() = *proto_msg;
  502. }
  503. return true;
  504. }
  505. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) {
  506. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) {
  507. return false;
  508. }
  509. ModelSerializeImp imp;
  510. if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) {
  511. GELOGE(GRAPH_FAILED, "AttrUtils::SetGraph SerializeGraph Failed");
  512. proto_attr_val.clear_g();
  513. return false;
  514. }
  515. return true;
  516. }
  517. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<ge::ComputeGraphPtr> &value) {
  518. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  519. proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) {
  520. return false;
  521. }
  522. auto list = proto_attr_val.mutable_list();
  523. GE_CHECK_NOTNULL_EXEC(list, return false);
  524. list->clear_g();
  525. ModelSerializeImp imp;
  526. for (const auto &item : value) {
  527. if (!imp.SerializeGraph(item, list->add_g())) {
  528. GELOGE(GRAPH_FAILED, "AttrUtils::SetListGraph SerializeGraph");
  529. proto_attr_val.clear_list();
  530. return false;
  531. }
  532. }
  533. return true;
  534. }
  535. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<vector<int64_t>> &value) {
  536. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) {
  537. return false;
  538. }
  539. proto_attr_val.clear_list_list_int();
  540. auto list_list_int = proto_attr_val.mutable_list_list_int();
  541. GE_CHECK_NOTNULL_EXEC(list_list_int, return false);
  542. for (auto &list_int : value) {
  543. auto list_item = list_list_int->add_list_list_i();
  544. GE_CHECK_NOTNULL_EXEC(list_item, return false);
  545. for (auto &int_item : list_int) {
  546. list_item->add_list_i(int_item);
  547. }
  548. }
  549. return true;
  550. }
  551. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<ge::DataType> &value) {
  552. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  553. proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) {
  554. return false;
  555. }
  556. auto list = proto_attr_val.mutable_list();
  557. GE_CHECK_NOTNULL_EXEC(list, return false);
  558. list->clear_dt();
  559. for (const auto &item : value) {
  560. list->add_dt(static_cast<int64_t>(item));
  561. }
  562. return true;
  563. }
  564. bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) {
  565. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) {
  566. return false;
  567. }
  568. proto_attr_val.set_dt(static_cast<int64_t>(value));
  569. return true;
  570. }
  // Generates GeAttrValueImp::GetValue for a scalar attribute.
  //   RefType   - output parameter type (a reference, e.g. `int64_t &`)
  //   CaseName  - expected proto::AttrDef value case (e.g. kI)
  //   FieldName - proto accessor that is read (e.g. i)
  // The output is written only when GetValueCheckType accepts the case; the check result is returned.
  #define ATTR_VALUE_IMP_GET_ONE(RefType, CaseName, FieldName) \
    bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, RefType value) { \
      const bool case_ok = AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::CaseName); \
      if (case_ok) { \
        value = proto_attr_val.FieldName(); \
      } \
      return case_ok; \
    }
  // Expands to a lambda reporting whether the repeated list field FieldName of an AttrDef has any
  // elements (FieldName##_size() is non-zero). The lambda is passed to
  // AttrUtilsHelper::GetValueCheckListType by the list getters.
  #define ListValueItemCheck(FieldName) \
    [](const proto::AttrDef &attr_def) { return attr_def.list().FieldName##_size() != 0; }
  // Generates GeAttrValueImp::GetValue for a list attribute.
  //   ElemType  - element type of the output vector
  //   ListCase  - suffix of the expected AttrDef_ListValue_ListValueType_* enumerator
  //   FieldName - repeated field of AttrDef::list() that is read
  // value is cleared first, so it is empty whenever false is returned. On success it holds a copy
  // of every element of the repeated field, in order.
  #define ATTR_VALUE_IMP_GET_LIST(ElemType, ListCase, FieldName) \
    bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector<ElemType> &value) { \
      value.clear(); \
      const bool case_ok = AttrUtilsHelper::GetValueCheckListType( \
          proto_attr_val, proto::AttrDef_ListValue_ListValueType_##ListCase, ListValueItemCheck(FieldName)); \
      if (case_ok) { \
        const auto &items = proto_attr_val.list().FieldName(); \
        value.insert(value.end(), items.begin(), items.end()); \
      } \
      return case_ok; \
    }
  595. ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i)
  596. ATTR_VALUE_IMP_GET_ONE(float &, kF, f)
  597. ATTR_VALUE_IMP_GET_ONE(string &, kS, s)
  598. ATTR_VALUE_IMP_GET_ONE(bool &, kB, b)
  599. ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i)
  600. ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f)
  601. ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s)
  602. ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b)
  603. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) {
  604. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) {
  605. return false;
  606. }
  607. auto proto_msg = value.tensor_descriptor_.GetProtoMsg();
  608. if (proto_msg == nullptr) {
  609. return false;
  610. }
  611. *proto_msg = proto_attr_val.td();
  612. return true;
  613. }
  614. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  615. vector<GeTensorDesc> &value) {
  616. if (!AttrUtilsHelper::GetValueCheckListType(
  617. proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) {
  618. return false;
  619. }
  620. auto &list = proto_attr_val.list();
  621. for (const auto &item : list.td()) {
  622. value.emplace_back(GeTensorDesc());
  623. auto proto_msg = value.back().tensor_descriptor_.GetProtoMsg();
  624. if (proto_msg == nullptr) {
  625. return false;
  626. }
  627. *proto_msg = item;
  628. }
  629. return true;
  630. }
  631. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  632. GeTensorPtr &value) {
  633. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) {
  634. return false;
  635. }
  636. value = std::shared_ptr<GeTensor>(
  637. new (std::nothrow) GeTensor(proto_owner, const_cast<proto::AttrDef &>(proto_attr_val).mutable_t()));
  638. GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr");
  639. return true;
  640. }
  641. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  642. vector<GeTensorPtr> &value) {
  643. value.clear();
  644. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR,
  645. ListValueItemCheck(t))) {
  646. return false;
  647. }
  648. auto list = const_cast<proto::AttrDef &>(proto_attr_val).mutable_list();
  649. GE_CHECK_NOTNULL_EXEC(list, return false);
  650. for (auto &item : *(list->mutable_t())) {
  651. std::shared_ptr<GeTensor> temp_value = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(proto_owner, &item));
  652. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  653. value.push_back(temp_value);
  654. }
  655. return true;
  656. }
  657. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) {
  658. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  659. return false;
  660. }
  661. auto &proto_val = proto_attr_val.bt();
  662. GE_LOGI_IF(proto_val.size() == 0, "size res is 0.");
  663. value = Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(proto_val.data()), proto_val.size());
  664. return true;
  665. }
  666. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  667. vector<GeAttrValue::BYTES> &value) {
  668. value.clear();
  669. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES,
  670. ListValueItemCheck(bt))) {
  671. return false;
  672. }
  673. auto &list = proto_attr_val.list();
  674. for (const auto &item : list.bt()) {
  675. value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size()));
  676. }
  677. return true;
  678. }
  679. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  680. GeAttrValue::NAMED_ATTRS &value) {
  681. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) {
  682. return false;
  683. }
  684. auto proto_msg = value.named_attrs_.GetProtoMsg();
  685. if (proto_msg == nullptr) {
  686. return false;
  687. }
  688. *proto_msg = proto_attr_val.func();
  689. return true;
  690. }
  691. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  692. vector<GeAttrValue::NAMED_ATTRS> &value) {
  693. value.clear();
  694. if (!AttrUtilsHelper::GetValueCheckListType(
  695. proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) {
  696. return false;
  697. }
  698. auto &list = proto_attr_val.list();
  699. for (const auto &item : list.na()) {
  700. value.emplace_back(GeAttrValue::NAMED_ATTRS());
  701. if (value.empty()) {
  702. return false;
  703. }
  704. auto proto_msg = value.back().named_attrs_.GetProtoMsg();
  705. if (proto_msg == nullptr) {
  706. return false;
  707. }
  708. *proto_msg = item;
  709. }
  710. return true;
  711. }
  712. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) {
  713. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) {
  714. return false;
  715. }
  716. ComputeGraphPtr graph = nullptr;
  717. std::shared_ptr<proto::GraphDef> graph_def;
  718. graph_def = ComGraphMakeShared<proto::GraphDef>(proto_attr_val.g());
  719. if (graph_def == nullptr) {
  720. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  721. graph_def = nullptr;
  722. return false; // lint !e665
  723. } else {
  724. ModelSerializeImp imp;
  725. imp.SetProtobufOwner(graph_def);
  726. if (!imp.UnserializeGraph(graph, *graph_def)) {
  727. GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
  728. return false;
  729. } // lint !e514
  730. value = graph;
  731. }
  732. return true;
  733. }
  734. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  735. vector<ComputeGraphPtr> &value) {
  736. value.clear();
  737. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH,
  738. ListValueItemCheck(g))) {
  739. return false;
  740. }
  741. auto &list = proto_attr_val.list();
  742. for (const auto &item : list.g()) {
  743. std::shared_ptr<proto::GraphDef> graph_def;
  744. graph_def = ComGraphMakeShared<proto::GraphDef>(item);
  745. if (graph_def == nullptr) {
  746. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  747. graph_def = nullptr;
  748. return false; // lint !e665
  749. } else {
  750. ComputeGraphPtr graph = nullptr;
  751. ModelSerializeImp imp;
  752. imp.SetProtobufOwner(graph_def);
  753. if (!imp.UnserializeGraph(graph, *graph_def)) {
  754. GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
  755. return false;
  756. } // lint !e514
  757. value.push_back(graph);
  758. }
  759. }
  760. return true;
  761. }
  762. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  763. vector<vector<int64_t>> &value) {
  764. value.clear();
  765. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) {
  766. return false;
  767. }
  768. auto &list_listint = proto_attr_val.list_list_int().list_list_i();
  769. for (auto &list_int : list_listint) {
  770. vector<int64_t> list_item(list_int.list_i().size());
  771. if (!list_int.list_i().empty()) {
  772. (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin());
  773. }
  774. value.push_back(list_item);
  775. }
  776. return true;
  777. }
  778. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  779. vector<ge::DataType> &value) {
  780. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE,
  781. ListValueItemCheck(dt))) {
  782. return false;
  783. }
  784. auto &list = proto_attr_val.list();
  785. for (const auto &item : list.dt()) {
  786. value.emplace_back(static_cast<ge::DataType>(item));
  787. }
  788. return true;
  789. }
  790. bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) {
  791. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) {
  792. return false;
  793. }
  794. value = static_cast<ge::DataType>(proto_attr_val.dt());
  795. return true;
  796. }
  797. GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  798. Buffer &&buffer) {
  799. if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  800. return false;
  801. }
  802. auto proto_msg = buffer.data_.GetProtoMsg();
  803. if (proto_msg == nullptr) {
  804. return false;
  805. }
  806. proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt()));
  807. return true;
  808. }
  809. bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  810. Buffer &buffer) {
  811. if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) {
  812. return false;
  813. }
  814. buffer = Buffer(proto_owner, &const_cast<proto::AttrDef &>(proto_attr_val));
  815. return true;
  816. }
  817. bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &,
  818. vector<Buffer> &list_buffer) {
  819. if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val,
  820. proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) {
  821. return false;
  822. }
  823. auto list = proto_attr_val.mutable_list();
  824. GE_CHECK_NOTNULL_EXEC(list, return false);
  825. list->clear_bt();
  826. for (auto &item : list_buffer) {
  827. auto proto_msg = item.data_.GetProtoMsg();
  828. if (proto_msg == nullptr) {
  829. return false;
  830. }
  831. list->add_bt(std::move(*proto_msg->mutable_bt()));
  832. }
  833. return true;
  834. }
  835. bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner,
  836. vector<Buffer> &list_buffer) {
  837. list_buffer.clear();
  838. if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES,
  839. ListValueItemCheck(bt))) {
  840. return false;
  841. }
  842. auto list = const_cast<proto::AttrDef &>(proto_attr_val).mutable_list();
  843. GE_CHECK_NOTNULL_EXEC(list, return false);
  844. for (auto &item : *(list->mutable_bt())) {
  845. list_buffer.emplace_back(Buffer(proto_owner, &item));
  846. }
  847. return true;
  848. }
  849. bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) {
  850. if (!obj) {
  851. return false;
  852. }
  853. return obj->HasAttr(name);
  854. }
  // Generates AttrUtils::Set<Suffix>(obj, name, value).
  // The attr entry `name` is obtained through AttrUtilsHelper::MutableAttrMapItem (presumably
  // created if absent; TODO confirm), then value is stored with GeAttrValueImp::SetValue.
  // A failed store logs a warning naming the key.
  #define ATTR_UTILS_SET_IMP(Suffix, ValueType) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##Suffix( \
        AttrHolderAdapter &&obj, const string &name, const ValueType &value) { \
      proto::AttrDef *proto_attr_val = nullptr; \
      const bool have_item = AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val); \
      if (!have_item || proto_attr_val == nullptr) { \
        return false; \
      } \
      if (GeAttrValueImp::SetValue(*proto_attr_val, value)) { \
        return true; \
      } \
      GELOGW("Set" #Suffix " failed key %s", name.c_str()); \
      return false; \
    }
  // Generates AttrUtils::Get<Suffix>(obj, name, value).
  // The attr entry `name` is looked up through AttrUtilsHelper::GetAttrMapItem and decoded with
  // GeAttrValueImp::GetValue using the holder's proto owner. A failed decode logs a warning naming
  // the key.
  #define ATTR_UTILS_GET_IMP(Suffix, ValueType) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##Suffix(ConstAttrHolderAdapter &&obj, \
                                                                               const string &name, ValueType &value) { \
      const proto::AttrDef *proto_attr_val = nullptr; \
      const bool have_item = AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val); \
      if (!have_item || proto_attr_val == nullptr) { \
        return false; \
      } \
      if (GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \
        return true; \
      } \
      GELOGW("Get" #Suffix " failed key %s", name.c_str()); \
      return false; \
    }
  // Generates both AttrUtils::Set<Suffix> and AttrUtils::Get<Suffix> for ValueType.
  #define ATTR_UTILS_SET_GET_IMP(Suffix, ValueType) \
    ATTR_UTILS_SET_IMP(Suffix, ValueType) \
    ATTR_UTILS_GET_IMP(Suffix, ValueType)
  884. ATTR_UTILS_SET_GET_IMP(Int, int64_t)
  885. ATTR_UTILS_SET_GET_IMP(Float, float)
  886. ATTR_UTILS_SET_GET_IMP(Bool, bool)
  887. ATTR_UTILS_SET_GET_IMP(Str, string)
  888. ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc)
  889. ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr)
  890. ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr)
  891. ATTR_UTILS_SET_IMP(Tensor, GeTensor)
  892. ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
  893. ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
  894. ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
  895. /*lint -e665*/
  896. ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>)
  897. /*lint +e665*/
  898. ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>)
  899. ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>)
  900. ATTR_UTILS_SET_IMP(ListInt, vector<uint32_t>)
  901. ATTR_UTILS_SET_GET_IMP(ListFloat, vector<float>)
  902. ATTR_UTILS_SET_GET_IMP(ListBool, vector<bool>)
  903. ATTR_UTILS_SET_GET_IMP(ListStr, vector<string>)
  904. ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>)
  905. ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>)
  906. ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>)
  907. ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
  908. ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
  909. ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
  910. ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
  911. ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665
  912. ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665
  913. bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name,
  914. std::initializer_list<ConstGeTensorPtr> &&value) {
  915. return SetListTensor(std::move(obj), name, vector<ConstGeTensorPtr>(value));
  916. }
  917. bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) {
  918. const proto::AttrDef *proto_attr_val = nullptr;
  919. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  920. return false;
  921. }
  922. GeTensorPtr tensor;
  923. if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) {
  924. return false;
  925. }
  926. value = tensor;
  927. return true;
  928. }
  929. bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value) {
  930. value.clear();
  931. const proto::AttrDef *proto_attr_val = nullptr;
  932. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  933. return false;
  934. }
  935. vector<GeTensorPtr> tensor;
  936. if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) {
  937. return false;
  938. }
  939. value.insert(value.begin(), tensor.begin(), tensor.end());
  940. return true;
  941. }
  942. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj,
  943. const string &name, GeTensorPtr &value) {
  944. const proto::AttrDef *proto_attr_val = nullptr;
  945. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  946. return false;
  947. }
  948. return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value);
  949. }
  950. bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value) {
  951. value.clear();
  952. const proto::AttrDef *proto_attr_val = nullptr;
  953. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  954. return false;
  955. }
  956. return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value);
  957. }
  958. bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value) {
  959. proto::AttrDef *proto_attr_val = nullptr;
  960. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  961. return false;
  962. }
  963. return GeAttrValueImp::SetValue(*proto_attr_val, value);
  964. }
  965. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name,
  966. int32_t &value) {
  967. int64_t int64_val = 0;
  968. if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) {
  969. return false;
  970. }
  971. if (int64_val > INT32_MAX) {
  972. GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to int32_t", int64_val);
  973. return false;
  974. }
  975. value = static_cast<int32_t>(int64_val);
  976. return true;
  977. }
  978. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name,
  979. uint32_t &value) {
  980. int64_t int64_val = 0;
  981. if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) {
  982. return false;
  983. }
  984. if (int64_val > UINT32_MAX) {
  985. GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to uint32_t", int64_val);
  986. return false;
  987. }
  988. value = static_cast<uint32_t>(int64_val);
  989. return true;
  990. }
  991. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj,
  992. const string &name, vector<int32_t> &value) {
  993. value.clear();
  994. vector<int64_t> int64_list;
  995. if (!GetListInt(std::move(obj), name, int64_list)) {
  996. return false;
  997. }
  998. for (size_t i = 0; i < int64_list.size(); ++i) {
  999. if (int64_list[i] > INT32_MAX) {
  1000. GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]);
  1001. return false;
  1002. }
  1003. }
  1004. value.insert(value.begin(), int64_list.begin(), int64_list.end());
  1005. return true;
  1006. }
  1007. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj,
  1008. const string &name, vector<uint32_t> &value) {
  1009. value.clear();
  1010. vector<int64_t> int64_list;
  1011. if (!GetListInt(std::move(obj), name, int64_list)) {
  1012. return false;
  1013. }
  1014. for (size_t i = 0; i < int64_list.size(); ++i) {
  1015. if (int64_list[i] > UINT32_MAX) {
  1016. GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]);
  1017. return false;
  1018. }
  1019. }
  1020. value.insert(value.begin(), int64_list.begin(), int64_list.end());
  1021. return true;
  1022. }
  1023. bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value) {
  1024. if (obj) {
  1025. vector<Buffer> bytes_vals;
  1026. for (auto &item : value) {
  1027. ModelSerialize serialize;
  1028. auto buffer = serialize.SerializeOpDesc(item);
  1029. if (buffer.GetSize() == 0) {
  1030. return false;
  1031. }
  1032. bytes_vals.push_back(buffer);
  1033. }
  1034. return SetZeroCopyListBytes(std::move(obj), name, bytes_vals);
  1035. }
  1036. return false;
  1037. }
  1038. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj,
  1039. const string &name,
  1040. const vector<OpDescPtr> &value) {
  1041. if (obj) {
  1042. vector<Buffer> bytes_vals;
  1043. for (auto &item : value) {
  1044. ModelSerialize serialize;
  1045. auto buffer = serialize.SerializeOpDesc(item);
  1046. if (buffer.GetSize() == 0) {
  1047. return false;
  1048. }
  1049. bytes_vals.push_back(buffer);
  1050. }
  1051. return SetZeroCopyListBytes(std::move(obj), name, bytes_vals);
  1052. }
  1053. return false;
  1054. }
  1055. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj,
  1056. const string &name,
  1057. vector<OpDescPtr> &value) {
  1058. value.clear();
  1059. vector<Buffer> bytes_vals;
  1060. if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) {
  1061. return false;
  1062. }
  1063. for (const auto &item : bytes_vals) {
  1064. ModelSerialize serialize;
  1065. auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732
  1066. value.push_back(op_desc);
  1067. }
  1068. return true;
  1069. }
  1070. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj,
  1071. const string &name, Buffer &&buffer) {
  1072. // Value will be moved
  1073. proto::AttrDef *proto_attr_val = nullptr;
  1074. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1075. return false;
  1076. }
  1077. return GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer));
  1078. }
  1079. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj,
  1080. const string &name, Buffer &buffer) {
  1081. const proto::AttrDef *proto_attr_val = nullptr;
  1082. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1083. return false;
  1084. }
  1085. return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer);
  1086. }
  1087. bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &list_buffer) {
  1088. // Value will be moved
  1089. proto::AttrDef *proto_attr_val = nullptr;
  1090. if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1091. return false;
  1092. }
  1093. return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer);
  1094. }
  1095. bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &list_buffer) {
  1096. list_buffer.clear();
  1097. const proto::AttrDef *proto_attr_val = nullptr;
  1098. if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
  1099. return false;
  1100. }
  1101. return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer);
  1102. }
  1103. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) {
  1104. if (org_op_desc == nullptr) {
  1105. GELOGE(GRAPH_FAILED, "org_op_desc is null");
  1106. return nullptr;
  1107. }
  1108. std::shared_ptr<proto::OpDef> op_def;
  1109. op_def = ComGraphMakeShared<proto::OpDef>();
  1110. if (op_def == nullptr) {
  1111. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  1112. return nullptr; // lint !e665
  1113. }
  1114. ModelSerializeImp imp;
  1115. (void)imp.SerializeOpDesc(org_op_desc, op_def.get());
  1116. imp.SetProtobufOwner(op_def);
  1117. OpDescPtr op_desc = nullptr;
  1118. GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed");
  1119. op_desc->extAttrs_ = org_op_desc->extAttrs_;
  1120. // This function may be called by some passes of fusion engine, in this condition, do not need these attribute
  1121. if (!op_desc->input_name_idx_.empty()) {
  1122. op_desc->input_name_idx_.clear();
  1123. }
  1124. if (!op_desc->output_name_idx_.empty()) {
  1125. op_desc->output_name_idx_.clear();
  1126. }
  1127. if (!op_desc->optional_input_names_.empty()) {
  1128. op_desc->optional_input_names_.clear();
  1129. }
  1130. return op_desc;
  1131. }
  1132. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) {
  1133. if (org_op_desc == nullptr) {
  1134. GELOGE(GRAPH_FAILED, "org_op_desc is null");
  1135. return nullptr;
  1136. }
  1137. std::shared_ptr<proto::OpDef> op_def = ComGraphMakeShared<proto::OpDef>();
  1138. if (op_def == nullptr) {
  1139. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  1140. return nullptr;
  1141. }
  1142. ModelSerializeImp imp;
  1143. (void)imp.SerializeOpDesc(org_op_desc, op_def.get());
  1144. imp.SetProtobufOwner(op_def);
  1145. OpDescPtr op_desc = nullptr;
  1146. GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed");
  1147. op_desc->extAttrs_ = org_op_desc->extAttrs_;
  1148. op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end());
  1149. op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(),
  1150. org_op_desc->optional_input_names_.end());
  1151. op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end());
  1152. op_desc->infer_func_ = org_op_desc->infer_func_;
  1153. op_desc->infer_format_func_ = org_op_desc->infer_format_func_;
  1154. op_desc->verifier_func_ = org_op_desc->verifier_func_;
  1155. return op_desc;
  1156. }
  1157. std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) {
  1158. auto holder = obj.get();
  1159. if (holder == nullptr) {
  1160. return "";
  1161. }
  1162. auto attrs_map = holder->GetAttrMap();
  1163. if (attrs_map.GetProtoMsg() == nullptr) {
  1164. return "";
  1165. }
  1166. std::map<std::string, std::string> ordered_attrs;
  1167. for (auto &attr : *(attrs_map.GetProtoMsg())) {
  1168. ordered_attrs[attr.first] = attr.second.SerializeAsString();
  1169. }
  1170. std::stringstream ss;
  1171. for (auto &attr : ordered_attrs) {
  1172. ss << attr.first << ":" << attr.second << ";";
  1173. }
  1174. return ss.str();
  1175. }
  1176. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：