You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_desc.cc 54 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/op_desc.h"
  17. #include "debug/ge_attr_define.h"
  18. #include "debug/ge_util.h"
  19. #include "external/graph/operator.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "common/util/error_manager/error_manager.h"
  22. #include "graph/common_error_codes.h"
  23. #include "graph/ge_attr_value.h"
  24. #include "graph/ge_tensor.h"
  25. #include "graph/operator_factory_impl.h"
  26. #include "graph/utils/attr_utils.h"
  27. #include "graph/utils/ge_ir_utils.h"
  28. #include "graph/utils/op_desc_utils.h"
  29. #include "graph/utils/transformer_utils.h"
  30. #include "proto/ge_ir.pb.h"
  31. using std::make_pair;
  32. using std::shared_ptr;
  33. using std::string;
  34. using std::vector;
  35. /*lint -save -e521 -e681 -e732 -e737*/
  36. namespace ge {
  // Attribute-map keys read by the proto-backed OpDesc constructor when it copies values into the
  // dedicated OpDef fields.
  const std::string ATTR_NAME_ID("id");
  const std::string ATTR_NAME_STREAM_ID("stream_id");
  const std::string ATTR_NAME_INPUT_NAME("input_name");
  const std::string ATTR_NAME_SRC_NAME("src_name");
  const std::string ATTR_NAME_SRC_INDEX("src_index");
  const std::string ATTR_NAME_INPUT("input");
  const std::string ATTR_NAME_OUTPUT("output");
  const std::string ATTR_NAME_INPUT_DESC("input_desc");
  const std::string ATTR_NAME_OUTPUT_DESC("output_desc");
  const std::string ATTR_NAME_DST_NAME("dst_name");
  const std::string ATTR_NAME_DST_INDEX("dst_index");
  const std::string ATTR_NAME_WORKSPACE("workspace");
  const std::string ATTR_NAME_WORKSPACE_BYTES("workspace_bytes");
  const std::string ATTR_NAME_IS_INPUT_CONST("is_input_const");
  // Keys with a leading underscore: the op's inference dependencies and the kernel library name
  // (the latter is written and read by Set/GetOpKernelLibName).
  const std::string ATTR_NAME_OP_INFER_DEPENDS("_op_infer_depends");
  const std::string ATTR_NAME_OP_KERNEL_LIB_NAME("_ge_attr_op_kernel_lib_name");
  53. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() {
  54. op_def_.InitDefault();
  55. if (op_def_.GetProtoMsg() != nullptr) {
  56. op_def_.GetProtoMsg()->set_has_out_attr(true);
  57. }
  58. }
  59. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {}
  60. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) {
  61. op_def_.InitDefault();
  62. if (op_def_.GetProtoMsg() != nullptr) {
  63. op_def_.GetProtoMsg()->set_has_out_attr(true);
  64. }
  65. SetName(name);
  66. SetType(type);
  67. }
  68. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner,
  69. ge::proto::OpDef *op_def)
  70. : op_def_(proto_msg_owner, op_def) {
  71. if (op_def != nullptr && !op_def->has_out_attr()) {
  72. op_def->set_has_out_attr(true);
  73. int64_t id = 0;
  74. (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id);
  75. op_def->set_id(id);
  76. int64_t stream_id = 0;
  77. (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id);
  78. op_def->set_stream_id(stream_id);
  79. vector<string> input_name;
  80. (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name);
  81. for (auto &item : input_name) {
  82. op_def->add_input_name(item);
  83. }
  84. vector<string> src_name;
  85. (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name);
  86. for (auto &item : src_name) {
  87. op_def->add_src_name(item);
  88. }
  89. vector<int64_t> src_index;
  90. (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index);
  91. for (auto &item : src_index) {
  92. op_def->add_src_index(item);
  93. }
  94. vector<int64_t> input;
  95. (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input);
  96. for (auto &item : input) {
  97. op_def->add_input_i(item);
  98. }
  99. vector<int64_t> output;
  100. (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output);
  101. for (auto &item : output) {
  102. op_def->add_output_i(item);
  103. }
  104. vector<string> dst_name;
  105. (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name);
  106. for (auto &item : dst_name) {
  107. op_def->add_dst_name(item);
  108. }
  109. vector<int64_t> dst_index;
  110. (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index);
  111. for (auto &item : dst_index) {
  112. op_def->add_dst_index(item);
  113. }
  114. vector<int64_t> workspace;
  115. (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace);
  116. for (auto &item : workspace) {
  117. op_def->add_workspace(item);
  118. }
  119. vector<int64_t> workspace_bytes;
  120. (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes);
  121. for (auto &item : workspace_bytes) {
  122. op_def->add_workspace_bytes(item);
  123. }
  124. vector<bool> is_input_const;
  125. (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const);
  126. for (auto item : is_input_const) {
  127. op_def->add_is_input_const(item);
  128. }
  129. auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list();
  130. if (input_desc_mutable_list != nullptr) {
  131. *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td());
  132. }
  133. auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list();
  134. if (output_desc_mutable_list != nullptr) {
  135. *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td());
  136. }
  137. }
  138. }
  139. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const {
  140. auto proto_msg = op_def_.GetProtoMsg();
  141. if (proto_msg != nullptr) {
  142. return proto_msg->name();
  143. }
  144. return "";
  145. }
  146. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) {
  147. auto proto_msg = op_def_.GetProtoMsg();
  148. if (proto_msg != nullptr) {
  149. proto_msg->set_name(name);
  150. }
  151. }
  152. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const {
  153. auto proto_msg = op_def_.GetProtoMsg();
  154. if (proto_msg != nullptr) {
  155. return proto_msg->type();
  156. }
  157. return "";
  158. }
  159. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) {
  160. auto proto_msg = op_def_.GetProtoMsg();
  161. if (proto_msg != nullptr) {
  162. proto_msg->set_type(type);
  163. }
  164. }
  165. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) {
  166. int index = static_cast<int>(inputs_desc_.size());
  167. return AddInputDesc("__input" + std::to_string(index), input_desc);
  168. }
  169. graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) {
  170. graphStatus ret = GRAPH_SUCCESS;
  171. if (index < inputs_desc_.size()) {
  172. // InputsDesc[index] is exist, then update it
  173. ret = UpdateInputDesc(index, input_desc);
  174. } else {
  175. // InputDesc[index] is not exist, then add it
  176. ret = AddInputDesc(input_desc);
  177. }
  178. return ret;
  179. }
  180. graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
  181. if (input_name_idx_.find(name) != input_name_idx_.end()) {
  182. GELOGI("input %s is exist, update it", name.c_str());
  183. graphStatus ret = UpdateInputDesc(name, input_desc);
  184. return ret;
  185. } else {
  186. int index = static_cast<int>(inputs_desc_.size());
  187. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(input_desc);
  188. if (in_desc == nullptr) {
  189. GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed.");
  190. return GRAPH_FAILED;
  191. }
  192. inputs_desc_.push_back(in_desc);
  193. (void)input_name_idx_.insert(make_pair(name, index));
  194. if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) {
  195. register_input_name_.push_back(name);
  196. }
  197. return GRAPH_SUCCESS;
  198. }
  199. }
  200. graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) {
  201. for (unsigned int i = 0; i < num; i++) {
  202. string input_name = name + std::to_string(i);
  203. GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
  204. "Add input tensor_desc is existed. name[%s]", input_name.c_str());
  205. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  206. if (in_desc == nullptr) {
  207. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed.");
  208. return GRAPH_FAILED;
  209. }
  210. if (index > inputs_desc_.size()) {
  211. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size.");
  212. return GRAPH_FAILED;
  213. }
  214. (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc);
  215. // Update index in input_name_idx
  216. for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
  217. if (it->second >= (index + i)) {
  218. it->second += 1;
  219. }
  220. }
  221. (void)input_name_idx_.insert(make_pair(input_name, i + index));
  222. }
  223. return GRAPH_SUCCESS;
  224. }
  225. graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) {
  226. for (unsigned int i = 0; i < num; i++) {
  227. string output_name = name + std::to_string(i);
  228. GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED,
  229. "Add input tensor_desc is existed. name[%s]", output_name.c_str());
  230. std::shared_ptr<GeTensorDesc> out_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  231. if (out_desc == nullptr) {
  232. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed.");
  233. return GRAPH_FAILED;
  234. }
  235. if (index > outputs_desc_.size()) {
  236. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size.");
  237. return GRAPH_FAILED;
  238. }
  239. (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc);
  240. // Update index in input_name_idx
  241. for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) {
  242. if (it->second >= (index + i)) {
  243. it->second += 1;
  244. }
  245. }
  246. (void)output_name_idx_.insert(make_pair(output_name, i + index));
  247. }
  248. return GRAPH_SUCCESS;
  249. }
  250. graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) {
  251. for (unsigned int i = 0; i < num; i++) {
  252. string input_name = name + std::to_string(i);
  253. GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
  254. "Add input tensor_desc is existed. name[%s]", input_name.c_str());
  255. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  256. if (in_desc == nullptr) {
  257. GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed.");
  258. return GRAPH_FAILED;
  259. }
  260. (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc);
  261. // Update index in input_name_idx
  262. for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
  263. it->second += 1;
  264. }
  265. (void)input_name_idx_.insert(make_pair(input_name, 0));
  266. }
  267. return GRAPH_SUCCESS;
  268. }
  269. graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) {
  270. for (unsigned int i = 0; i < num; i++) {
  271. string output_name = name + std::to_string(i);
  272. GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED,
  273. "Add output tensor_desc is existed. name[%s]", output_name.c_str());
  274. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  275. if (in_desc == nullptr) {
  276. GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed.");
  277. return GRAPH_FAILED;
  278. }
  279. (void)outputs_desc_.insert(outputs_desc_.begin(), in_desc);
  280. // Update index in output_name_idx
  281. for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) {
  282. it->second += 1;
  283. }
  284. (void)output_name_idx_.insert(make_pair(output_name, 0));
  285. }
  286. return GRAPH_SUCCESS;
  287. }
  288. graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
  289. if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED;
  290. (void)optional_input_names_.insert(name);
  291. return GRAPH_SUCCESS;
  292. }
  293. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  294. OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
  295. if (index >= inputs_desc_.size()) {
  296. GELOGW("The index is invalid. index[%u]", index);
  297. return GRAPH_FAILED;
  298. }
  299. inputs_desc_[index] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  300. if (inputs_desc_[index] == nullptr) {
  301. GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed.");
  302. return GRAPH_FAILED;
  303. }
  304. return GRAPH_SUCCESS;
  305. }
  306. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const {
  307. return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") &&
  308. IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
  309. IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") &&
  310. IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
  311. IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
  312. }
  313. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const {
  314. const auto &op_def = this->op_def_.GetProtoMsg();
  315. const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg();
  316. if ((op_def != nullptr) && (r_op_def != nullptr)) {
  317. // Message OpDef in ge_ir.proto
  318. return (
  319. IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") &&
  320. IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") &&
  321. IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") &&
  322. IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") &&
  323. IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") &&
  324. IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") &&
  325. IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") &&
  326. IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") &&
  327. IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") &&
  328. IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") &&
  329. IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") &&
  330. IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") &&
  331. IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") &&
  332. IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()),
  333. "OpDef_.workspace_bytes()") &&
  334. IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()"));
  335. } else {
  336. return ((op_def == nullptr) && (r_op_def == nullptr));
  337. }
  338. }
  339. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc)
  340. const {
  341. // 1.Verify inputs and outputs desc size
  342. const auto inputs_desc_size = this->inputs_desc_.size();
  343. const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size();
  344. if (inputs_desc_size != r_inputs_desc_size) {
  345. GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str());
  346. return false;
  347. }
  348. const auto outputs_desc_size = this->outputs_desc_.size();
  349. const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size();
  350. if (outputs_desc_size != r_outputs_desc_size) {
  351. GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str());
  352. return false;
  353. }
  354. // 2.Verify all inputs desc equal
  355. for (uint32_t i = 0; i < inputs_desc_size; i++) {
  356. const auto &in_ge_tensor_desc = this->GetInputDesc(i);
  357. const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i);
  358. // Determine the connection relationship by GeTensorDesc
  359. if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) {
  360. GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.",
  361. this->GetName().c_str());
  362. return false;
  363. }
  364. }
  365. // 3.Verify all outputs desc equal
  366. for (uint32_t i = 0; i < outputs_desc_size; i++) {
  367. const auto &out_ge_tensor_desc = this->GetOutputDesc(i);
  368. const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i);
  369. if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) {
  370. GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.",
  371. this->GetName().c_str());
  372. return false;
  373. }
  374. }
  375. return true;
  376. }
  377. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const {
  378. return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) &&
  379. OpDescGenTensorDescsAreEqual(r_op_desc));
  380. }
  381. graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
  382. auto it = input_name_idx_.find(name);
  383. if (it == input_name_idx_.end()) {
  384. GELOGW("Cann't find the input desc. name[%s]", name.c_str());
  385. return GRAPH_FAILED;
  386. }
  387. if (it->second >= inputs_desc_.size()) {
  388. GELOGE(GRAPH_FAILED, "[%d] more than size of inputs_desc_", it->second);
  389. return GRAPH_FAILED;
  390. }
  391. GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid.");
  392. return GRAPH_FAILED);
  393. inputs_desc_[it->second] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  394. if (inputs_desc_[it->second] == nullptr) {
  395. GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed.");
  396. return GRAPH_FAILED;
  397. }
  398. return GRAPH_SUCCESS;
  399. }
  400. bool OpDesc::InputIsSet(const string &name) const {
  401. auto it = input_name_idx_.find(name);
  402. if (it != input_name_idx_.end()) {
  403. GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false);
  404. auto tensor_desc = inputs_desc_[it->second];
  405. GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false);
  406. auto dims = tensor_desc->GetShape().GetDims();
  407. if (dims.size() > 0) {
  408. return true;
  409. }
  410. }
  411. return false;
  412. }
  413. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const {
  414. GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc());
  415. return *(inputs_desc_[index].get());
  416. }
  417. GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
  418. auto it = input_name_idx_.find(name);
  419. GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc());
  420. GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc());
  421. return *(inputs_desc_[it->second].get());
  422. }
  423. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
  424. GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
  425. if (inputs_desc_[index] == nullptr) {
  426. return nullptr;
  427. }
  428. if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) {
  429. GELOGW("input desc is invalid");
  430. return nullptr;
  431. }
  432. return inputs_desc_[index];
  433. }
  434. GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const {
  435. auto input_name_idx = GetAllInputName();
  436. auto it = input_name_idx.find(name);
  437. if (it == input_name_idx.end()) {
  438. GELOGW("Failed to get [%s] input desc", name.c_str());
  439. return nullptr;
  440. }
  441. return MutableInputDesc(it->second);
  442. }
  443. GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
  444. vector<string> names;
  445. if (input_name_idx_.empty()) {
  446. return OpDesc::Vistor<string>(shared_from_this(), names);
  447. }
  448. for (std::pair<string, uint32_t> input : input_name_idx_) {
  449. names.push_back(input.first);
  450. }
  451. return OpDesc::Vistor<string>(shared_from_this(), names);
  452. }
  453. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) {
  454. op_kernel_lib_name_ = name;
  455. auto ret = AttrUtils::SetStr(this, ATTR_NAME_OP_KERNEL_LIB_NAME, name);
  456. if (ret != true) {
  457. GELOGE(GRAPH_FAILED, "set op kernel lib name failed.");
  458. }
  459. }
  460. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const {
  461. if (!op_kernel_lib_name_.empty()) {
  462. return op_kernel_lib_name_;
  463. }
  464. string op_kernel_lib_name;
  465. (void)AttrUtils::GetStr(this, ATTR_NAME_OP_KERNEL_LIB_NAME, op_kernel_lib_name);
  466. return op_kernel_lib_name;
  467. }
  468. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) {
  469. engine_name_ = name;
  470. }
  471. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; }
  472. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const {
  473. vector<GeTensorDesc> temp{};
  474. for (const auto &it : inputs_desc_) {
  475. if (it->IsValid() == GRAPH_SUCCESS) {
  476. temp.push_back(*it);
  477. } else {
  478. GELOGW("this inputDesc is InValid, it won't be return");
  479. continue;
  480. }
  481. }
  482. return OpDesc::Vistor<GeTensorDesc>(shared_from_this(), temp);
  483. }
  484. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDescPtr> OpDesc::GetAllInputsDescPtr() const {
  485. vector<GeTensorDescPtr> temp{};
  486. for (const auto &it : inputs_desc_) {
  487. if (it->IsValid() == GRAPH_SUCCESS) {
  488. temp.push_back(it);
  489. } else {
  490. GELOGW("this inputDesc is InValid, it won't be return");
  491. continue;
  492. }
  493. }
  494. return OpDesc::Vistor<GeTensorDescPtr>(shared_from_this(), temp);
  495. }
  496. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const {
  497. // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register.
  498. size_t size = 0;
  499. for (const auto &it : inputs_desc_) {
  500. if (it->IsValid() == GRAPH_SUCCESS) {
  501. size++;
  502. }
  503. }
  504. return size;
  505. }
  506. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); }
  507. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) {
  508. int index = static_cast<int>(outputs_desc_.size());
  509. return AddOutputDesc("__output" + std::to_string(index), output_desc);
  510. }
  511. graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) {
  512. GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED,
  513. "Add output tensor_Desc is existed. name[%s]", name.c_str());
  514. int index = static_cast<int>(outputs_desc_.size());
  515. std::shared_ptr<GeTensorDesc> tensor = ComGraphMakeShared<GeTensorDesc>(output_desc);
  516. if (tensor == nullptr) {
  517. GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed.");
  518. return GRAPH_FAILED;
  519. }
  520. outputs_desc_.push_back(tensor);
  521. (void)output_name_idx_.insert(make_pair(name, index));
  522. if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) {
  523. register_output_name_.push_back(name);
  524. }
  525. return GRAPH_SUCCESS;
  526. }
  527. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  528. OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
  529. GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);
  530. outputs_desc_[index] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  531. if (outputs_desc_[index] == nullptr) {
  532. GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed.");
  533. return GRAPH_FAILED;
  534. }
  535. return GRAPH_SUCCESS;
  536. }
  537. graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
  538. auto it = output_name_idx_.find(name);
  539. if (it == output_name_idx_.end()) {
  540. GELOGW("Cann't find the output desc. name[%s]", name.c_str());
  541. return GRAPH_FAILED;
  542. }
  543. GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid.");
  544. return GRAPH_FAILED);
  545. outputs_desc_[it->second] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  546. if (outputs_desc_[it->second] == nullptr) {
  547. GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed.");
  548. return GRAPH_FAILED;
  549. }
  550. return GRAPH_SUCCESS;
  551. }
  552. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const {
  553. GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc());
  554. return *(outputs_desc_[index].get());
  555. }
  556. GeTensorDesc OpDesc::GetOutputDesc(const string &name) const {
  557. auto it = output_name_idx_.find(name);
  558. GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc());
  559. GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc());
  560. return *(outputs_desc_[it->second].get());
  561. }
  562. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const {
  563. GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index);
  564. return outputs_desc_[index];
  565. }
  566. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const {
  567. auto it = output_name_idx_.find(name);
  568. if (it == output_name_idx_.end()) {
  569. GELOGW("Failed to get [%s] output desc", name.c_str());
  570. return nullptr;
  571. }
  572. return MutableOutputDesc(it->second);
  573. }
  574. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
  575. return static_cast<uint32_t>(outputs_desc_.size());
  576. }
  577. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const {
  578. vector<GeTensorDesc> temp{};
  579. for (const auto &it : outputs_desc_) {
  580. temp.push_back(*it);
  581. }
  582. return OpDesc::Vistor<GeTensorDesc>(shared_from_this(), temp);
  583. }
  584. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDescPtr> OpDesc::GetAllOutputsDescPtr() const {
  585. return OpDesc::Vistor<GeTensorDescPtr>(shared_from_this(), outputs_desc_);
  586. }
  587. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); }
  588. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const {
  589. GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(outputs_desc_.size()), nullptr);
  590. return outputs_desc_[index];
  591. }
  592. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const {
  593. GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(inputs_desc_.size()), nullptr);
  594. if (inputs_desc_[index] == nullptr) {
  595. return nullptr;
  596. }
  597. if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) {
  598. GELOGW("inputsDesc[%u] is InValid", index);
  599. return nullptr;
  600. } else {
  601. return inputs_desc_[static_cast<size_t>(index)];
  602. }
  603. }
  604. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr
  605. OpDesc::GetInputDescPtrDfault(uint32_t index) const {
  606. GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr);
  607. return inputs_desc_[(int32_t)index];
  608. }
  609. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const {
  610. auto it = input_name_idx_.find(name);
  611. GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr<const GeTensorDesc>());
  612. return inputs_desc_[it->second];
  613. }
  614. graphStatus OpDesc::AddRegisterInputName(const std::string &name) {
  615. if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) {
  616. register_input_name_.push_back(name);
  617. }
  618. return GRAPH_SUCCESS;
  619. }
  620. vector<string> OpDesc::GetRegisterInputName() const {
  621. return register_input_name_;
  622. }
  623. graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) {
  624. if (is_push_back) {
  625. for (unsigned int i = 0; i < num; i++) {
  626. if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS)
  627. return GRAPH_FAILED;
  628. }
  629. } else {
  630. if (AddInputDescForward(name, num) != GRAPH_SUCCESS)
  631. return GRAPH_FAILED;
  632. }
  633. if (AddRegisterInputName(name) != GRAPH_SUCCESS) {
  634. return GRAPH_FAILED;
  635. }
  636. return GRAPH_SUCCESS;
  637. }
  638. graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) {
  639. if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) {
  640. return GRAPH_FAILED;
  641. }
  642. return GRAPH_SUCCESS;
  643. }
  644. graphStatus OpDesc::AddRegisterOutputName(const string &name) {
  645. if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) {
  646. register_output_name_.push_back(name);
  647. }
  648. return GRAPH_SUCCESS;
  649. }
  650. vector<string> OpDesc::GetRegisterOutputName() const {
  651. return register_output_name_;
  652. }
  653. graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) {
  654. if (is_push_back) {
  655. for (unsigned int i = 0; i < num; i++) {
  656. if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS)
  657. return GRAPH_FAILED;
  658. }
  659. } else {
  660. if (AddOutputDescForward(name, num) != GRAPH_SUCCESS)
  661. return GRAPH_FAILED;
  662. }
  663. if (AddRegisterOutputName(name) != GRAPH_SUCCESS) {
  664. return GRAPH_FAILED;
  665. }
  666. return GRAPH_SUCCESS;
  667. }
  668. bool OpDesc::IsOptionalInput(const string &name) const {
  669. return optional_input_names_.find(name) != optional_input_names_.end();
  670. }
  671. bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); }
  672. std::map<string, uint32_t> OpDesc::GetAllInputName() const { return input_name_idx_; }
  673. std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; }
  674. std::map<string, uint32_t>& OpDesc::MutableAllInputName() { return input_name_idx_; }
  675. std::map<string, uint32_t>& OpDesc::MutableAllOutputName() { return output_name_idx_; }
  676. bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
  677. bool ret = true;
  678. // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name.
  679. auto input_map_size = inputs_desc_.size();
  680. auto factory_map_size = input_name_idx.size();
  681. // It indicates that some inputs have no optionalname.
  682. // The redundant optionalname of factory needs to be deleted and then assigned
  683. if (input_map_size < factory_map_size) {
  684. GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size,
  685. factory_map_size);
  686. for (auto it = input_name_idx.begin(); it != input_name_idx.end();) {
  687. if (it->second >= input_map_size) {
  688. it = input_name_idx.erase(it);
  689. } else {
  690. ++it;
  691. }
  692. }
  693. if (input_name_idx.size() == input_map_size) {
  694. GELOGI("UpdateInputName");
  695. input_name_idx_ = input_name_idx;
  696. } else {
  697. ret = false;
  698. GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size());
  699. }
  700. } else if (input_map_size == factory_map_size) {
  701. input_name_idx_ = input_name_idx;
  702. } else {
  703. ret = false;
  704. GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size);
  705. }
  706. return ret;
  707. }
  708. bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) {
  709. size_t output_map_size = GetAllOutputsDescSize();
  710. size_t factory_map_size = output_name_idx.size();
  711. if (output_map_size < factory_map_size) {
  712. GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size,
  713. factory_map_size);
  714. for (auto it = output_name_idx.begin(); it != output_name_idx.end();) {
  715. if (it->second >= output_map_size) {
  716. it = output_name_idx.erase(it);
  717. } else {
  718. ++it;
  719. }
  720. }
  721. if (output_name_idx.size() == output_map_size) {
  722. GELOGI("UpdateoutputName");
  723. output_name_idx_ = output_name_idx;
  724. return true;
  725. }
  726. } else if (output_map_size == factory_map_size) {
  727. output_name_idx_ = output_name_idx;
  728. return true;
  729. } else {
  730. GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size);
  731. return false;
  732. }
  733. GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size);
  734. return false;
  735. }
  736. std::function<graphStatus(Operator &)> OpDesc::GetInferFunc() const { return infer_func_; }
  737. std::function<graphStatus(Operator &)> OpDesc::GetVerifyFunc() const { return verifier_func_; }
  738. void OpDesc::AddInferFunc(const std::function<graphStatus(Operator &)> &func) { infer_func_ = func; }
  739. std::function<graphStatus(Operator &)> OpDesc::GetInferFormatFunc() const { return infer_format_func_; }
  740. void OpDesc::AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func) { infer_format_func_ = func; }
  741. void OpDesc::AddVerifierFunc(const std::function<graphStatus(Operator &)> &func) { verifier_func_ = func; }
  742. graphStatus OpDesc::InferShapeAndType() {
  743. if (infer_func_ == nullptr) {
  744. infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType());
  745. if (infer_func_ == nullptr) {
  746. GELOGW("%s does not have inferfunc_.", GetName().c_str());
  747. /// The infoshape function has not been added for each operator in the current operator information library.
  748. /// No infoshape added operator skips the call
  749. /// and directly uses the shape information passed down by the upper framework
  750. return GRAPH_SUCCESS;
  751. }
  752. }
  753. Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this());
  754. graphStatus ret = (graphStatus)infer_func_(op_proxy);
  755. op_proxy.BreakConnect();
  756. return ret;
  757. }
  758. graphStatus OpDesc::DefaultInferFormat() {
  759. ge::Format first_none_nd_format = FORMAT_ND;
  760. auto input_descs = GetAllInputsDescPtr();
  761. auto output_descs = GetAllOutputsDescPtr();
  762. // Overall input and output,get the first non-nd format
  763. for (const auto &input_desc : input_descs) {
  764. Format origin_format = input_desc->GetOriginFormat();
  765. if (origin_format != FORMAT_ND) {
  766. first_none_nd_format = origin_format;
  767. break;
  768. }
  769. }
  770. for (const auto &output_desc : output_descs) {
  771. Format origin_format = output_desc->GetOriginFormat();
  772. if (origin_format != FORMAT_ND) {
  773. first_none_nd_format = origin_format;
  774. break;
  775. }
  776. }
  777. // Refresh all input output format
  778. GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format);
  779. for (const auto &input_desc : input_descs) {
  780. Format origin_format = input_desc->GetOriginFormat();
  781. GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format);
  782. if (origin_format == FORMAT_ND) {
  783. input_desc->SetOriginFormat(first_none_nd_format);
  784. input_desc->SetFormat(first_none_nd_format);
  785. }
  786. }
  787. for (const auto &output_desc : output_descs) {
  788. Format origin_format = output_desc->GetOriginFormat();
  789. GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format);
  790. if (origin_format == FORMAT_ND) {
  791. output_desc->SetOriginFormat(first_none_nd_format);
  792. output_desc->SetFormat(first_none_nd_format);
  793. }
  794. }
  795. return GRAPH_SUCCESS;
  796. }
  797. graphStatus OpDesc::OpVerify() {
  798. if (verifier_func_ == nullptr) {
  799. verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType());
  800. }
  801. if (verifier_func_ != nullptr) {
  802. Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this());
  803. graphStatus ret = (graphStatus)verifier_func_(op_proxy);
  804. op_proxy.BreakConnect();
  805. return ret;
  806. }
  807. return GRAPH_SUCCESS;
  808. }
  809. graphStatus OpDesc::CommonVerify() const {
  810. for (const string &iname : GetAllInputNames()) {
  811. // Checking shape of all inputs
  812. vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims();
  813. for (int64_t dim : ishape) {
  814. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(dim < -2,
  815. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  816. {GetName(), "input " + iname + " shape", "contains negative or zero dimension"});
  817. return GRAPH_FAILED,
  818. "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), iname.c_str());
  819. }
  820. }
  821. // Check all attributes defined
  822. const auto &all_attributes = GetAllAttrs();
  823. for (const auto &name : GetAllAttrNames()) {
  824. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(all_attributes.find(name) == all_attributes.end(),
  825. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  826. {GetName(), "attribute " + name, "is empty"});
  827. return GRAPH_FAILED,
  828. "operator attribute %s is empty.", name.c_str());
  829. }
  830. return GRAPH_SUCCESS;
  831. }
  832. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const {
  833. auto it = input_name_idx_.begin();
  834. for (; it != input_name_idx_.end(); ++it) {
  835. if (it->second == index) {
  836. break;
  837. }
  838. }
  839. GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), "");
  840. return it->first;
  841. }
  842. int OpDesc::GetInputIndexByName(const string &name) const {
  843. auto it_find = input_name_idx_.find(name);
  844. GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1);
  845. return static_cast<int>(it_find->second);
  846. }
  847. int OpDesc::GetValidInputIndexByName(const string &name) const {
  848. map<string, uint32_t> valid_input_name_idx{};
  849. uint32_t j = 0;
  850. for (size_t i = 0; i < GetAllInputsSize(); i++) {
  851. if (MutableInputDesc(static_cast<uint32_t>(i)) != nullptr) {
  852. auto valid_name = GetInputNameByIndex(static_cast<uint32_t>(i));
  853. GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), -1);
  854. valid_input_name_idx.insert({valid_name, j});
  855. j++;
  856. }
  857. }
  858. auto it_find = valid_input_name_idx.find(name);
  859. GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != valid_input_name_idx.end(), -1);
  860. return static_cast<int>(it_find->second);
  861. }
  862. string OpDesc::GetValidInputNameByIndex(uint32_t index) const {
  863. map<string, uint32_t> valid_input_name_idx{};
  864. uint32_t j = 0;
  865. for (size_t i = 0; i < GetAllInputsSize(); i++) {
  866. if (MutableInputDesc(static_cast<uint32_t>(i)) != nullptr) {
  867. auto valid_name = GetInputNameByIndex(static_cast<uint32_t>(i));
  868. GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), "");
  869. valid_input_name_idx.insert({valid_name, j});
  870. j++;
  871. }
  872. }
  873. auto it = valid_input_name_idx.begin();
  874. for (; it != valid_input_name_idx.end(); ++it) {
  875. if (it->second == index) {
  876. break;
  877. }
  878. }
  879. GE_CHK_BOOL_RET_STATUS_NOLOG(it != valid_input_name_idx.end(), "");
  880. return it->first;
  881. }
  882. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const {
  883. auto it = output_name_idx_.begin();
  884. for (; it != output_name_idx_.end(); ++it) {
  885. if (it->second == index) {
  886. break;
  887. }
  888. }
  889. GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), "");
  890. return it->first;
  891. }
  892. int OpDesc::GetOutputIndexByName(const string &name) const {
  893. auto it_find = output_name_idx_.find(name);
  894. GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1);
  895. return static_cast<int>(it_find->second);
  896. }
  897. ProtoAttrMapHelper OpDesc::MutableAttrMap() {
  898. if (op_def_.GetProtoMsg() == nullptr) {
  899. GELOGE(GRAPH_FAILED, "op def get proto msg failed");
  900. return GeIrProtoHelper<ProtoAttrMap>();
  901. }
  902. return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr());
  903. }
  904. ConstProtoAttrMapHelper OpDesc::GetAttrMap() const {
  905. return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr());
  906. }
  907. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) {
  908. auto proto_msg = op_def_.GetProtoMsg();
  909. if (proto_msg != nullptr) {
  910. proto_msg->set_id(id);
  911. }
  912. }
  913. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const {
  914. auto proto_msg = op_def_.GetProtoMsg();
  915. if (proto_msg != nullptr) {
  916. return proto_msg->id();
  917. }
  918. return 0;
  919. }
  920. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) {
  921. auto proto_msg = op_def_.GetProtoMsg();
  922. if (proto_msg != nullptr) {
  923. proto_msg->set_stream_id(stream_id);
  924. }
  925. }
  926. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const {
  927. auto proto_msg = op_def_.GetProtoMsg();
  928. if (proto_msg != nullptr) {
  929. return proto_msg->stream_id();
  930. }
  931. return 0;
  932. }
  933. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector<string> &input_name) {
  934. auto proto_msg = op_def_.GetProtoMsg();
  935. if (proto_msg != nullptr) {
  936. proto_msg->clear_input_name();
  937. for (auto &item : input_name) {
  938. proto_msg->add_input_name(item);
  939. }
  940. }
  941. }
  942. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetInputName() const {
  943. vector<string> input_name;
  944. auto proto_msg = op_def_.GetProtoMsg();
  945. if (proto_msg != nullptr) {
  946. for (auto &item : proto_msg->input_name()) {
  947. input_name.push_back(item);
  948. }
  949. }
  950. return input_name;
  951. }
  952. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector<string> &src_name) {
  953. auto proto_msg = op_def_.GetProtoMsg();
  954. if (proto_msg != nullptr) {
  955. proto_msg->clear_src_name();
  956. for (auto &item : src_name) {
  957. proto_msg->add_src_name(item);
  958. }
  959. }
  960. }
  961. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetSrcName() const {
  962. vector<string> src_name;
  963. auto proto_msg = op_def_.GetProtoMsg();
  964. if (proto_msg != nullptr) {
  965. for (auto &item : proto_msg->src_name()) {
  966. src_name.push_back(item);
  967. }
  968. }
  969. return src_name;
  970. }
  971. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector<int64_t> &src_index) {
  972. auto proto_msg = op_def_.GetProtoMsg();
  973. if (proto_msg != nullptr) {
  974. proto_msg->clear_src_index();
  975. for (auto &item : src_index) {
  976. proto_msg->add_src_index(item);
  977. }
  978. }
  979. }
  980. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetSrcIndex() const {
  981. vector<int64_t> src_index;
  982. auto proto_msg = op_def_.GetProtoMsg();
  983. if (proto_msg != nullptr) {
  984. for (auto &item : proto_msg->src_index()) {
  985. src_index.push_back(item);
  986. }
  987. }
  988. return src_index;
  989. }
  990. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector<int64_t> &input) {
  991. auto proto_msg = op_def_.GetProtoMsg();
  992. if (proto_msg != nullptr) {
  993. proto_msg->clear_input_i();
  994. for (auto &item : input) {
  995. proto_msg->add_input_i(item);
  996. }
  997. }
  998. }
  999. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetInputOffset() const {
  1000. vector<int64_t> input;
  1001. auto proto_msg = op_def_.GetProtoMsg();
  1002. if (proto_msg != nullptr) {
  1003. for (auto &item : proto_msg->input_i()) {
  1004. input.push_back(item);
  1005. }
  1006. }
  1007. return input;
  1008. }
  1009. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector<int64_t> &output) {
  1010. auto proto_msg = op_def_.GetProtoMsg();
  1011. if (proto_msg != nullptr) {
  1012. proto_msg->clear_output_i();
  1013. for (auto &item : output) {
  1014. proto_msg->add_output_i(item);
  1015. }
  1016. }
  1017. }
  1018. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetOutputOffset() const {
  1019. vector<int64_t> output;
  1020. auto proto_msg = op_def_.GetProtoMsg();
  1021. if (proto_msg != nullptr) {
  1022. for (auto &item : proto_msg->output_i()) {
  1023. output.push_back(item);
  1024. }
  1025. }
  1026. return output;
  1027. }
  1028. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector<string> &dst_name) {
  1029. auto proto_msg = op_def_.GetProtoMsg();
  1030. if (proto_msg != nullptr) {
  1031. proto_msg->clear_dst_name();
  1032. for (auto &item : dst_name) {
  1033. proto_msg->add_dst_name(item);
  1034. }
  1035. }
  1036. }
  1037. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetDstName() const {
  1038. vector<string> dst_name;
  1039. auto proto_msg = op_def_.GetProtoMsg();
  1040. if (proto_msg != nullptr) {
  1041. for (auto &item : proto_msg->dst_name()) {
  1042. dst_name.push_back(item);
  1043. }
  1044. }
  1045. return dst_name;
  1046. }
  1047. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector<string> &depend_names) {
  1048. auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names);
  1049. if (ret != true) {
  1050. GELOGE(GRAPH_FAILED, "set op_infer_depends fail.");
  1051. }
  1052. }
  1053. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetOpInferDepends() const {
  1054. vector<string> depend_names;
  1055. (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names);
  1056. return depend_names;
  1057. }
  1058. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector<int64_t> &dst_index) {
  1059. auto proto_msg = op_def_.GetProtoMsg();
  1060. if (proto_msg != nullptr) {
  1061. proto_msg->clear_dst_index();
  1062. for (auto &item : dst_index) {
  1063. proto_msg->add_dst_index(item);
  1064. }
  1065. }
  1066. }
  1067. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetDstIndex() const {
  1068. vector<int64_t> dst_index;
  1069. auto proto_msg = op_def_.GetProtoMsg();
  1070. if (proto_msg != nullptr) {
  1071. for (auto &item : proto_msg->dst_index()) {
  1072. dst_index.push_back(item);
  1073. }
  1074. }
  1075. return dst_index;
  1076. }
  1077. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector<int64_t> &workspace) {
  1078. auto proto_msg = op_def_.GetProtoMsg();
  1079. if (proto_msg != nullptr) {
  1080. proto_msg->clear_workspace();
  1081. for (auto &item : workspace) {
  1082. proto_msg->add_workspace(item);
  1083. }
  1084. }
  1085. }
  1086. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetWorkspace() const {
  1087. vector<int64_t> workspace;
  1088. auto proto_msg = op_def_.GetProtoMsg();
  1089. if (proto_msg != nullptr) {
  1090. for (auto &item : proto_msg->workspace()) {
  1091. workspace.push_back(item);
  1092. }
  1093. }
  1094. return workspace;
  1095. }
  1096. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector<int64_t> &workspace_bytes) {
  1097. auto proto_msg = op_def_.GetProtoMsg();
  1098. if (proto_msg != nullptr) {
  1099. proto_msg->clear_workspace_bytes();
  1100. for (auto &item : workspace_bytes) {
  1101. proto_msg->add_workspace_bytes(item);
  1102. }
  1103. }
  1104. }
  1105. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetWorkspaceBytes() const {
  1106. vector<int64_t> workspace_bytes;
  1107. auto proto_msg = op_def_.GetProtoMsg();
  1108. if (proto_msg != nullptr) {
  1109. for (auto &item : proto_msg->workspace_bytes()) {
  1110. workspace_bytes.push_back(item);
  1111. }
  1112. }
  1113. return workspace_bytes;
  1114. }
  1115. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector<bool> &is_input_const) {
  1116. auto proto_msg = op_def_.GetProtoMsg();
  1117. if (proto_msg != nullptr) {
  1118. proto_msg->clear_is_input_const();
  1119. for (auto item : is_input_const) {
  1120. proto_msg->add_is_input_const(item);
  1121. }
  1122. }
  1123. // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag
  1124. auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const);
  1125. if (ret != true) {
  1126. GELOGE(GRAPH_FAILED, "set is_input_const fail.");
  1127. }
  1128. }
  1129. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputConst() const {
  1130. vector<bool> is_input_const;
  1131. auto proto_msg = op_def_.GetProtoMsg();
  1132. if (proto_msg != nullptr) {
  1133. for (auto item : proto_msg->is_input_const()) {
  1134. is_input_const.push_back(item);
  1135. }
  1136. }
  1137. return is_input_const;
  1138. }
  1139. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name,
  1140. const int &index) {
  1141. if (input_name_idx_.find(name) != input_name_idx_.end()) {
  1142. GELOGI("Restore input name index is existed. name[%s]", name.c_str());
  1143. }
  1144. (void)input_name_idx_.insert(make_pair(name, index));
  1145. return GRAPH_SUCCESS;
  1146. }
  1147. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name,
  1148. const int &index) {
  1149. if (output_name_idx_.find(name) != output_name_idx_.end()) {
  1150. GELOGI("Restore output name index is existed. name[%s]", name.c_str());
  1151. }
  1152. (void)output_name_idx_.insert(make_pair(name, index));
  1153. return GRAPH_SUCCESS;
  1154. }
  1155. graphStatus OpDesc::CallInferFunc(Operator &op) {
  1156. if (infer_func_ == nullptr) {
  1157. infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType());
  1158. if (infer_func_ == nullptr) {
  1159. GELOGW("%s does not have infer func.", GetName().c_str());
  1160. return GRAPH_PARAM_INVALID;
  1161. }
  1162. }
  1163. std::unique_ptr<NodeShapeTransUtils> transformer(new(std::nothrow) NodeShapeTransUtils(shared_from_this()));
  1164. if (transformer == nullptr) {
  1165. GELOGE(GRAPH_FAILED, "Memory alloc failed");
  1166. return GRAPH_FAILED;
  1167. }
  1168. if (!transformer->CatchFormatAndShape()) {
  1169. GELOGE(GRAPH_FAILED, "catch format and shape info failed!");
  1170. return GRAPH_FAILED;
  1171. }
  1172. graphStatus graph_status = (graphStatus)infer_func_(op);
  1173. if (graph_status != GRAPH_SUCCESS) {
  1174. GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status);
  1175. return GRAPH_FAILED;
  1176. }
  1177. if (!transformer->UpdateFormatAndShape()) {
  1178. GELOGE(GRAPH_FAILED, "catch format and shape info failed!");
  1179. return GRAPH_FAILED;
  1180. }
  1181. return GRAPH_SUCCESS;
  1182. }
  1183. graphStatus OpDesc::CallInferFormatFunc(Operator &op) {
  1184. if (infer_format_func_ == nullptr) {
  1185. infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType());
  1186. if (infer_format_func_ == nullptr) {
  1187. return DefaultInferFormat();
  1188. }
  1189. }
  1190. return (graphStatus)infer_format_func_(op);
  1191. }
  1192. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const {
  1193. if (static_cast<size_t>(index) >= subgraph_instance_names_.size()) {
  1194. return "";
  1195. }
  1196. return subgraph_instance_names_.at(index);
  1197. }
  1198. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames()
  1199. const {
  1200. return subgraph_instance_names_;
  1201. }
  1202. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) {
  1203. for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) {
  1204. if (*iter == name) {
  1205. *iter = "";
  1206. return;
  1207. }
  1208. }
  1209. }
  1210. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) {
  1211. GELOGI("Add subgraph name is %s", name.c_str());
  1212. auto iter = subgraph_names_to_index_.find(name);
  1213. if (iter != subgraph_names_to_index_.end()) {
  1214. GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second);
  1215. return GRAPH_FAILED;
  1216. }
  1217. auto size = subgraph_names_to_index_.size();
  1218. subgraph_names_to_index_[name] = size;
  1219. subgraph_instance_names_.resize(size + 1);
  1220. return GRAPH_SUCCESS;
  1221. }
  1222. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> & OpDesc::GetSubgraphNameIndexes()
  1223. const {
  1224. return subgraph_names_to_index_;
  1225. }
  1226. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  1227. graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, const std::string &name) {
  1228. GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index);
  1229. if (index >= subgraph_instance_names_.size()) {
  1230. GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size());
  1231. return GRAPH_PARAM_INVALID;
  1232. }
  1233. subgraph_instance_names_[index] = name;
  1234. return GRAPH_SUCCESS;
  1235. }
  1236. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  1237. void OpDesc::RegisterSubgraphIrName(const string &name, SubgraphType type) {
  1238. subgraph_ir_names_to_type_[name] = type;
  1239. }
  1240. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  1241. const std::map<std::string, SubgraphType> &OpDesc::GetSubgraphIrNames() const {
  1242. return subgraph_ir_names_to_type_;
  1243. }
  1244. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  1245. SubgraphType OpDesc::GetSubgraphTypeByIrName(const std::string &name) const {
  1246. auto iter = subgraph_ir_names_to_type_.find(name);
  1247. if (iter == subgraph_ir_names_to_type_.end()) {
  1248. return kSubgraphTypeEnd;
  1249. }
  1250. return iter->second;
  1251. }
  1252. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
  1253. graphStatus OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const {
  1254. for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) {
  1255. if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index.
  1256. continue;
  1257. }
  1258. for (auto name_to_index : subgraph_names_to_index_) {
  1259. if (name_to_index.second != idx) { // find subgraph name.
  1260. continue;
  1261. }
  1262. subgraph_name = name_to_index.first;
  1263. return GRAPH_SUCCESS;
  1264. }
  1265. }
  1266. return GRAPH_PARAM_INVALID;
  1267. }
  1268. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::InferDataSlice() {
  1269. if (infer_data_slice_func_ == nullptr) {
  1270. infer_data_slice_func_ = OperatorFactoryImpl::GetInferDataSliceFunc(GetType());
  1271. if (infer_data_slice_func_ == nullptr) {
  1272. GELOGW("%s does not have infer data slice func.", GetName().c_str());
  1273. return NO_DEPENDENCE_FUNC;
  1274. }
  1275. }
  1276. Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this());
  1277. graphStatus ret = (graphStatus)infer_data_slice_func_(op_proxy);
  1278. op_proxy.BreakConnect();
  1279. return ret;
  1280. }
  1281. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：