You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

op_desc.cc 51 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/op_desc.h"
  17. #include "debug/ge_attr_define.h"
  18. #include "debug/ge_util.h"
  19. #include "external/graph/operator.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "common/util/error_manager/error_manager.h"
  22. #include "graph/ge_attr_value.h"
  23. #include "graph/ge_tensor.h"
  24. #include "graph/operator_factory_impl.h"
  25. #include "graph/utils/attr_utils.h"
  26. #include "graph/utils/ge_ir_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  28. #include "proto/ge_ir.pb.h"
  29. using std::make_pair;
  30. using std::shared_ptr;
  31. using std::string;
  32. using std::vector;
  33. /*lint -save -e521 -e681 -e732 -e737*/
  34. namespace ge {
  // Attribute keys used by the OpDesc(ProtoMsgOwner, OpDef*) constructor in this file when it copies
  // values of a legacy OpDef (has_out_attr == false) out of its attributes into dedicated fields.
  const std::string ATTR_NAME_ID{"id"};
  const std::string ATTR_NAME_STREAM_ID{"stream_id"};
  const std::string ATTR_NAME_INPUT_NAME{"input_name"};
  const std::string ATTR_NAME_SRC_NAME{"src_name"};
  const std::string ATTR_NAME_SRC_INDEX{"src_index"};
  const std::string ATTR_NAME_INPUT{"input"};
  const std::string ATTR_NAME_OUTPUT{"output"};
  const std::string ATTR_NAME_INPUT_DESC{"input_desc"};
  const std::string ATTR_NAME_OUTPUT_DESC{"output_desc"};
  const std::string ATTR_NAME_DST_NAME{"dst_name"};
  const std::string ATTR_NAME_DST_INDEX{"dst_index"};
  const std::string ATTR_NAME_WORKSPACE{"workspace"};
  const std::string ATTR_NAME_WORKSPACE_BYTES{"workspace_bytes"};
  const std::string ATTR_NAME_IS_INPUT_CONST{"is_input_const"};
  // Not referenced in the visible part of this file; presumably used by code further down — verify.
  const std::string ATTR_NAME_OP_INFER_DEPENDS{"_op_infer_depends"};
  50. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() {
  51. op_def_.InitDefault();
  52. if (op_def_.GetProtoMsg() != nullptr) {
  53. op_def_.GetProtoMsg()->set_has_out_attr(true);
  54. }
  55. }
  56. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {}
  57. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) {
  58. op_def_.InitDefault();
  59. if (op_def_.GetProtoMsg() != nullptr) {
  60. op_def_.GetProtoMsg()->set_has_out_attr(true);
  61. }
  62. SetName(name);
  63. SetType(type);
  64. }
  65. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner,
  66. ge::proto::OpDef *op_def)
  67. : op_def_(proto_msg_owner, op_def) {
  68. if (op_def != nullptr && !op_def->has_out_attr()) {
  69. op_def->set_has_out_attr(true);
  70. int64_t id = 0;
  71. (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id);
  72. op_def->set_id(id);
  73. int64_t stream_id = 0;
  74. (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id);
  75. op_def->set_stream_id(stream_id);
  76. vector<string> input_name;
  77. (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name);
  78. for (auto &item : input_name) {
  79. op_def->add_input_name(item);
  80. }
  81. vector<string> src_name;
  82. (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name);
  83. for (auto &item : src_name) {
  84. op_def->add_src_name(item);
  85. }
  86. vector<int64_t> src_index;
  87. (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index);
  88. for (auto &item : src_index) {
  89. op_def->add_src_index(item);
  90. }
  91. vector<int64_t> input;
  92. (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input);
  93. for (auto &item : input) {
  94. op_def->add_input_i(item);
  95. }
  96. vector<int64_t> output;
  97. (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output);
  98. for (auto &item : output) {
  99. op_def->add_output_i(item);
  100. }
  101. vector<string> dst_name;
  102. (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name);
  103. for (auto &item : dst_name) {
  104. op_def->add_dst_name(item);
  105. }
  106. vector<int64_t> dst_index;
  107. (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index);
  108. for (auto &item : dst_index) {
  109. op_def->add_dst_index(item);
  110. }
  111. vector<int64_t> workspace;
  112. (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace);
  113. for (auto &item : workspace) {
  114. op_def->add_workspace(item);
  115. }
  116. vector<int64_t> workspace_bytes;
  117. (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes);
  118. for (auto &item : workspace_bytes) {
  119. op_def->add_workspace_bytes(item);
  120. }
  121. vector<bool> is_input_const;
  122. (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const);
  123. for (auto item : is_input_const) {
  124. op_def->add_is_input_const(item);
  125. }
  126. auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list();
  127. if (input_desc_mutable_list != nullptr) {
  128. *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td());
  129. }
  130. auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list();
  131. if (output_desc_mutable_list != nullptr) {
  132. *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td());
  133. }
  134. }
  135. }
  136. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const {
  137. auto proto_msg = op_def_.GetProtoMsg();
  138. if (proto_msg != nullptr) {
  139. return proto_msg->name();
  140. }
  141. return "";
  142. }
  143. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) {
  144. auto proto_msg = op_def_.GetProtoMsg();
  145. if (proto_msg != nullptr) {
  146. proto_msg->set_name(name);
  147. }
  148. }
  149. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const {
  150. auto proto_msg = op_def_.GetProtoMsg();
  151. if (proto_msg != nullptr) {
  152. return proto_msg->type();
  153. }
  154. return "";
  155. }
  156. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) {
  157. auto proto_msg = op_def_.GetProtoMsg();
  158. if (proto_msg != nullptr) {
  159. proto_msg->set_type(type);
  160. }
  161. }
  162. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) {
  163. int index = static_cast<int>(inputs_desc_.size());
  164. return AddInputDesc("__input" + std::to_string(index), input_desc);
  165. }
  166. graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) {
  167. graphStatus ret = GRAPH_SUCCESS;
  168. if (index < inputs_desc_.size()) {
  169. // InputsDesc[index] is exist, then update it
  170. ret = UpdateInputDesc(index, input_desc);
  171. } else {
  172. // InputDesc[index] is not exist, then add it
  173. ret = AddInputDesc(input_desc);
  174. }
  175. return ret;
  176. }
  177. graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
  178. if (input_name_idx_.find(name) != input_name_idx_.end()) {
  179. GELOGI("input %s is exist, update it", name.c_str());
  180. graphStatus ret = UpdateInputDesc(name, input_desc);
  181. return ret;
  182. } else {
  183. int index = static_cast<int>(inputs_desc_.size());
  184. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(input_desc);
  185. if (in_desc == nullptr) {
  186. GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed.");
  187. return GRAPH_FAILED;
  188. }
  189. inputs_desc_.push_back(in_desc);
  190. (void)input_name_idx_.insert(make_pair(name, index));
  191. if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) {
  192. register_input_name_.push_back(name);
  193. }
  194. return GRAPH_SUCCESS;
  195. }
  196. }
  197. graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) {
  198. for (unsigned int i = 0; i < num; i++) {
  199. string input_name = name + std::to_string(i);
  200. GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
  201. "Add input tensor_desc is existed. name[%s]", input_name.c_str());
  202. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  203. if (in_desc == nullptr) {
  204. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed.");
  205. return GRAPH_FAILED;
  206. }
  207. if (index > inputs_desc_.size()) {
  208. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size.");
  209. return GRAPH_FAILED;
  210. }
  211. (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc);
  212. // Update index in input_name_idx
  213. for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
  214. if (it->second >= (index + i)) {
  215. it->second += 1;
  216. }
  217. }
  218. (void)input_name_idx_.insert(make_pair(input_name, i + index));
  219. }
  220. return GRAPH_SUCCESS;
  221. }
  222. graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) {
  223. for (unsigned int i = 0; i < num; i++) {
  224. string output_name = name + std::to_string(i);
  225. GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED,
  226. "Add input tensor_desc is existed. name[%s]", output_name.c_str());
  227. std::shared_ptr<GeTensorDesc> out_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  228. if (out_desc == nullptr) {
  229. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed.");
  230. return GRAPH_FAILED;
  231. }
  232. if (index > outputs_desc_.size()) {
  233. GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size.");
  234. return GRAPH_FAILED;
  235. }
  236. (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc);
  237. // Update index in input_name_idx
  238. for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) {
  239. if (it->second >= (index + i)) {
  240. it->second += 1;
  241. }
  242. }
  243. (void)output_name_idx_.insert(make_pair(output_name, i + index));
  244. }
  245. return GRAPH_SUCCESS;
  246. }
  247. graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) {
  248. for (unsigned int i = 0; i < num; i++) {
  249. string input_name = name + std::to_string(i);
  250. GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
  251. "Add input tensor_desc is existed. name[%s]", input_name.c_str());
  252. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  253. if (in_desc == nullptr) {
  254. GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed.");
  255. return GRAPH_FAILED;
  256. }
  257. (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc);
  258. // Update index in input_name_idx
  259. for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
  260. it->second += 1;
  261. }
  262. (void)input_name_idx_.insert(make_pair(input_name, 0));
  263. }
  264. return GRAPH_SUCCESS;
  265. }
  266. graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) {
  267. for (unsigned int i = 0; i < num; i++) {
  268. string output_name = name + std::to_string(i);
  269. GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED,
  270. "Add output tensor_desc is existed. name[%s]", output_name.c_str());
  271. std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
  272. if (in_desc == nullptr) {
  273. GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed.");
  274. return GRAPH_FAILED;
  275. }
  276. (void)outputs_desc_.insert(outputs_desc_.begin(), in_desc);
  277. // Update index in output_name_idx
  278. for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) {
  279. it->second += 1;
  280. }
  281. (void)output_name_idx_.insert(make_pair(output_name, 0));
  282. }
  283. return GRAPH_SUCCESS;
  284. }
  285. graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
  286. if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED;
  287. (void)optional_input_names_.insert(name);
  288. return GRAPH_SUCCESS;
  289. }
  290. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  291. OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
  292. GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);
  293. inputs_desc_[index] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  294. if (inputs_desc_[index] == nullptr) {
  295. GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed.");
  296. return GRAPH_FAILED;
  297. }
  298. return GRAPH_SUCCESS;
  299. }
  300. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const {
  301. return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") &&
  302. IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
  303. IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") &&
  304. IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
  305. IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
  306. }
  307. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const {
  308. const auto &op_def = this->op_def_.GetProtoMsg();
  309. const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg();
  310. if ((op_def != nullptr) && (r_op_def != nullptr)) {
  311. // Message OpDef in ge_ir.proto
  312. return (
  313. IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") &&
  314. IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") &&
  315. IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") &&
  316. IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") &&
  317. IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") &&
  318. IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") &&
  319. IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") &&
  320. IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") &&
  321. IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") &&
  322. IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") &&
  323. IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") &&
  324. IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") &&
  325. IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") &&
  326. IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), "OpDef_.workspace_bytes()") &&
  327. IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()"));
  328. } else {
  329. return ((op_def == nullptr) && (r_op_def == nullptr));
  330. }
  331. }
  332. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual(
  333. const OpDesc &r_op_desc) const {
  334. // 1.Verify inputs and outputs desc size
  335. const auto inputs_desc_size = this->inputs_desc_.size();
  336. const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size();
  337. if (inputs_desc_size != r_inputs_desc_size) {
  338. GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str());
  339. return false;
  340. }
  341. const auto outputs_desc_size = this->outputs_desc_.size();
  342. const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size();
  343. if (outputs_desc_size != r_outputs_desc_size) {
  344. GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str());
  345. return false;
  346. }
  347. // 2.Verify all inputs desc equal
  348. for (uint32_t i = 0; i < inputs_desc_size; i++) {
  349. const auto &in_ge_tensor_desc = this->GetInputDesc(i);
  350. const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i);
  351. // Determine the connection relationship by GeTensorDesc
  352. if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) {
  353. GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.",
  354. this->GetName().c_str());
  355. return false;
  356. }
  357. }
  358. // 3.Verify all outputs desc equal
  359. for (uint32_t i = 0; i < outputs_desc_size; i++) {
  360. const auto &out_ge_tensor_desc = this->GetOutputDesc(i);
  361. const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i);
  362. if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) {
  363. GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.",
  364. this->GetName().c_str());
  365. return false;
  366. }
  367. }
  368. return true;
  369. }
  370. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const {
  371. return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) &&
  372. OpDescGenTensorDescsAreEqual(r_op_desc));
  373. }
  374. graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
  375. auto it = input_name_idx_.find(name);
  376. if (it == input_name_idx_.end()) {
  377. GELOGW("Cann't find the input desc. name[%s]", name.c_str());
  378. return GRAPH_FAILED;
  379. }
  380. if (it->second >= inputs_desc_.size()) {
  381. GELOGE(GRAPH_FAILED, "[%d] more than size of inputs_desc_", it->second);
  382. return GRAPH_FAILED;
  383. }
  384. GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid.");
  385. return GRAPH_FAILED);
  386. inputs_desc_[it->second] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  387. if (inputs_desc_[it->second] == nullptr) {
  388. GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed.");
  389. return GRAPH_FAILED;
  390. }
  391. return GRAPH_SUCCESS;
  392. }
  393. bool OpDesc::InputIsSet(const string &name) const {
  394. auto it = input_name_idx_.find(name);
  395. if (it != input_name_idx_.end()) {
  396. GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false);
  397. auto tensor_desc = inputs_desc_[it->second];
  398. GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false);
  399. auto dims = tensor_desc->GetShape().GetDims();
  400. if (dims.size() > 0) {
  401. return true;
  402. }
  403. }
  404. return false;
  405. }
  406. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const {
  407. GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc());
  408. return *(inputs_desc_[index].get());
  409. }
  410. GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
  411. auto it = input_name_idx_.find(name);
  412. GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc());
  413. GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc());
  414. return *(inputs_desc_[it->second].get());
  415. }
  416. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
  417. GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
  418. if (inputs_desc_[index] == nullptr) {
  419. return nullptr;
  420. }
  421. if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) {
  422. GELOGW("input desc is invalid");
  423. return nullptr;
  424. }
  425. return inputs_desc_[index];
  426. }
  427. GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const {
  428. auto input_name_idx = GetAllInputName();
  429. auto it = input_name_idx.find(name);
  430. if (it == input_name_idx.end()) {
  431. GELOGW("Failed to get [%s] input desc", name.c_str());
  432. return nullptr;
  433. }
  434. return MutableInputDesc(it->second);
  435. }
  436. GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
  437. vector<string> names;
  438. if (input_name_idx_.empty()) {
  439. return OpDesc::Vistor<string>(shared_from_this(), names);
  440. }
  441. for (std::pair<string, uint32_t> input : input_name_idx_) {
  442. names.push_back(input.first);
  443. }
  444. return OpDesc::Vistor<string>(shared_from_this(), names);
  445. }
  446. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) {
  447. op_kernel_lib_name_ = name;
  448. }
  449. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const {
  450. return op_kernel_lib_name_;
  451. }
  452. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) {
  453. engine_name_ = name;
  454. }
  455. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; }
  456. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const {
  457. vector<GeTensorDesc> temp{};
  458. for (const auto &it : inputs_desc_) {
  459. if (it->IsValid() == GRAPH_SUCCESS) {
  460. temp.push_back(*it);
  461. } else {
  462. GELOGW("this inputDesc is InValid, it won't be return");
  463. continue;
  464. }
  465. }
  466. return OpDesc::Vistor<GeTensorDesc>(shared_from_this(), temp);
  467. }
  468. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDescPtr> OpDesc::GetAllInputsDescPtr() const {
  469. vector<GeTensorDescPtr> temp{};
  470. for (const auto &it : inputs_desc_) {
  471. if (it->IsValid() == GRAPH_SUCCESS) {
  472. temp.push_back(it);
  473. } else {
  474. GELOGW("this inputDesc is InValid, it won't be return");
  475. continue;
  476. }
  477. }
  478. return OpDesc::Vistor<GeTensorDescPtr>(shared_from_this(), temp);
  479. }
  480. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const {
  481. // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register.
  482. size_t size = 0;
  483. for (const auto &it : inputs_desc_) {
  484. if (it->IsValid() == GRAPH_SUCCESS) {
  485. size++;
  486. }
  487. }
  488. return size;
  489. }
  490. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); }
  491. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) {
  492. int index = static_cast<int>(outputs_desc_.size());
  493. return AddOutputDesc("__output" + std::to_string(index), output_desc);
  494. }
  495. graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) {
  496. GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED,
  497. "Add output tensor_Desc is existed. name[%s]", name.c_str());
  498. int index = static_cast<int>(outputs_desc_.size());
  499. std::shared_ptr<GeTensorDesc> tensor = ComGraphMakeShared<GeTensorDesc>(output_desc);
  500. if (tensor == nullptr) {
  501. GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed.");
  502. return GRAPH_FAILED;
  503. }
  504. outputs_desc_.push_back(tensor);
  505. (void)output_name_idx_.insert(make_pair(name, index));
  506. if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) {
  507. register_output_name_.push_back(name);
  508. }
  509. return GRAPH_SUCCESS;
  510. }
  511. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  512. OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
  513. GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);
  514. outputs_desc_[index] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  515. if (outputs_desc_[index] == nullptr) {
  516. GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed.");
  517. return GRAPH_FAILED;
  518. }
  519. return GRAPH_SUCCESS;
  520. }
  521. graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
  522. auto it = output_name_idx_.find(name);
  523. if (it == output_name_idx_.end()) {
  524. GELOGW("Cann't find the output desc. name[%s]", name.c_str());
  525. return GRAPH_FAILED;
  526. }
  527. GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid.");
  528. return GRAPH_FAILED);
  529. outputs_desc_[it->second] = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  530. if (outputs_desc_[it->second] == nullptr) {
  531. GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed.");
  532. return GRAPH_FAILED;
  533. }
  534. return GRAPH_SUCCESS;
  535. }
  536. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const {
  537. GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc());
  538. return *(outputs_desc_[index].get());
  539. }
  540. GeTensorDesc OpDesc::GetOutputDesc(const string &name) const {
  541. auto it = output_name_idx_.find(name);
  542. GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc());
  543. GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc());
  544. return *(outputs_desc_[it->second].get());
  545. }
  546. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const {
  547. GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index);
  548. return outputs_desc_[index];
  549. }
  550. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const {
  551. auto it = output_name_idx_.find(name);
  552. if (it == output_name_idx_.end()) {
  553. GELOGW("Failed to get [%s] output desc", name.c_str());
  554. return nullptr;
  555. }
  556. return MutableOutputDesc(it->second);
  557. }
  558. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
  559. return static_cast<uint32_t>(outputs_desc_.size());
  560. }
  561. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const {
  562. vector<GeTensorDesc> temp{};
  563. for (const auto &it : outputs_desc_) {
  564. temp.push_back(*it);
  565. }
  566. return OpDesc::Vistor<GeTensorDesc>(shared_from_this(), temp);
  567. }
  568. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDescPtr> OpDesc::GetAllOutputsDescPtr() const {
  569. return OpDesc::Vistor<GeTensorDescPtr>(shared_from_this(), outputs_desc_);
  570. }
  571. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); }
  572. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const {
  573. GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(outputs_desc_.size()), nullptr);
  574. return outputs_desc_[index];
  575. }
  576. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const {
  577. GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(inputs_desc_.size()), nullptr);
  578. if (inputs_desc_[index] == nullptr) {
  579. return nullptr;
  580. }
  581. if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) {
  582. GELOGW("inputsDesc[%u] is InValid", index);
  583. return nullptr;
  584. } else {
  585. return inputs_desc_[static_cast<size_t>(index)];
  586. }
  587. }
  588. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr
  589. OpDesc::GetInputDescPtrDfault(uint32_t index) const {
  590. GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr);
  591. return inputs_desc_[(int32_t)index];
  592. }
  593. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const {
  594. auto it = input_name_idx_.find(name);
  595. GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr<const GeTensorDesc>());
  596. return inputs_desc_[it->second];
  597. }
  598. graphStatus OpDesc::AddRegisterInputName(const std::string &name) {
  599. if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) {
  600. register_input_name_.push_back(name);
  601. }
  602. return GRAPH_SUCCESS;
  603. }
  604. vector<string> OpDesc::GetRegisterInputName() const { return register_input_name_; }
  605. graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) {
  606. if (is_push_back) {
  607. for (unsigned int i = 0; i < num; i++) {
  608. if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED;
  609. }
  610. } else {
  611. if (AddInputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED;
  612. }
  613. if (AddRegisterInputName(name) != GRAPH_SUCCESS) {
  614. return GRAPH_FAILED;
  615. }
  616. return GRAPH_SUCCESS;
  617. }
  618. graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) {
  619. if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) {
  620. return GRAPH_FAILED;
  621. }
  622. return GRAPH_SUCCESS;
  623. }
  624. graphStatus OpDesc::AddRegisterOutputName(const string &name) {
  625. if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) {
  626. register_output_name_.push_back(name);
  627. }
  628. return GRAPH_SUCCESS;
  629. }
  630. vector<string> OpDesc::GetRegisterOutputName() const { return register_output_name_; }
  631. graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) {
  632. if (is_push_back) {
  633. for (unsigned int i = 0; i < num; i++) {
  634. if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED;
  635. }
  636. } else {
  637. if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED;
  638. }
  639. if (AddRegisterOutputName(name) != GRAPH_SUCCESS) {
  640. return GRAPH_FAILED;
  641. }
  642. return GRAPH_SUCCESS;
  643. }
  644. bool OpDesc::IsOptionalInput(const string &name) const {
  645. return optional_input_names_.find(name) != optional_input_names_.end();
  646. }
  647. bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); }
  648. std::map<string, uint32_t> OpDesc::GetAllInputName() const { return input_name_idx_; }
  649. std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; }
  650. bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
  651. bool ret = true;
  652. // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name.
  653. auto input_map_size = inputs_desc_.size();
  654. auto factory_map_size = input_name_idx.size();
  655. // It indicates that some inputs have no optionalname.
  656. // The redundant optionalname of factory needs to be deleted and then assigned
  657. if (input_map_size < factory_map_size) {
  658. GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size,
  659. factory_map_size);
  660. for (auto it = input_name_idx.begin(); it != input_name_idx.end();) {
  661. if (it->second >= input_map_size) {
  662. it = input_name_idx.erase(it);
  663. } else {
  664. ++it;
  665. }
  666. }
  667. if (input_name_idx.size() == input_map_size) {
  668. GELOGI("UpdateInputName");
  669. input_name_idx_ = input_name_idx;
  670. } else {
  671. ret = false;
  672. GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size());
  673. }
  674. } else if (input_map_size == factory_map_size) {
  675. input_name_idx_ = input_name_idx;
  676. } else {
  677. ret = false;
  678. GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size);
  679. }
  680. return ret;
  681. }
  682. bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) {
  683. size_t output_map_size = GetAllOutputsDescSize();
  684. size_t factory_map_size = output_name_idx.size();
  685. if (output_map_size < factory_map_size) {
  686. GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size,
  687. factory_map_size);
  688. for (auto it = output_name_idx.begin(); it != output_name_idx.end();) {
  689. if (it->second >= output_map_size) {
  690. it = output_name_idx.erase(it);
  691. } else {
  692. ++it;
  693. }
  694. }
  695. if (output_name_idx.size() == output_map_size) {
  696. GELOGI("UpdateoutputName");
  697. output_name_idx_ = output_name_idx;
  698. return true;
  699. }
  700. } else if (output_map_size == factory_map_size) {
  701. output_name_idx_ = output_name_idx;
  702. return true;
  703. } else {
  704. GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size);
  705. return false;
  706. }
  707. GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size);
  708. return false;
  709. }
  710. std::function<graphStatus(Operator &)> OpDesc::GetInferFunc() const { return infer_func_; }
  711. std::function<graphStatus(Operator &)> OpDesc::GetVerifyFunc() const { return verifier_func_; }
  712. void OpDesc::AddInferFunc(const std::function<graphStatus(Operator &)> &func) { infer_func_ = func; }
  713. std::function<graphStatus(Operator &)> OpDesc::GetInferFormatFunc() const { return infer_format_func_; }
  714. void OpDesc::AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func) { infer_format_func_ = func; }
  715. void OpDesc::AddVerifierFunc(const std::function<graphStatus(Operator &)> &func) { verifier_func_ = func; }
  716. graphStatus OpDesc::InferShapeAndType() {
  717. if (infer_func_ == nullptr) {
  718. infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType());
  719. if (infer_func_ == nullptr) {
  720. GELOGW("%s does not have inferfunc_.", GetName().c_str());
  721. /// The infoshape function has not been added for each operator in the current operator information library.
  722. /// No infoshape added operator skips the call
  723. /// and directly uses the shape information passed down by the upper framework
  724. return GRAPH_SUCCESS;
  725. }
  726. }
  727. Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this());
  728. graphStatus ret = (graphStatus)infer_func_(op_proxy);
  729. op_proxy.BreakConnect();
  730. return ret;
  731. }
  732. graphStatus OpDesc::DefaultInferFormat() {
  733. ge::Format first_none_nd_format = FORMAT_ND;
  734. auto input_descs = GetAllInputsDescPtr();
  735. auto output_descs = GetAllOutputsDescPtr();
  736. // Overall input and output,get the first non-nd format
  737. for (const auto &input_desc : input_descs) {
  738. Format origin_format = input_desc->GetOriginFormat();
  739. if (origin_format != FORMAT_ND) {
  740. first_none_nd_format = origin_format;
  741. break;
  742. }
  743. }
  744. for (const auto &output_desc : output_descs) {
  745. Format origin_format = output_desc->GetOriginFormat();
  746. if (origin_format != FORMAT_ND) {
  747. first_none_nd_format = origin_format;
  748. break;
  749. }
  750. }
  751. // Refresh all input output format
  752. GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format);
  753. for (const auto &input_desc : input_descs) {
  754. Format origin_format = input_desc->GetOriginFormat();
  755. GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format);
  756. if (origin_format == FORMAT_ND) {
  757. input_desc->SetOriginFormat(first_none_nd_format);
  758. input_desc->SetFormat(first_none_nd_format);
  759. }
  760. }
  761. for (const auto &output_desc : output_descs) {
  762. Format origin_format = output_desc->GetOriginFormat();
  763. GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format);
  764. if (origin_format == FORMAT_ND) {
  765. output_desc->SetOriginFormat(first_none_nd_format);
  766. output_desc->SetFormat(first_none_nd_format);
  767. }
  768. }
  769. return GRAPH_SUCCESS;
  770. }
  771. graphStatus OpDesc::OpVerify() {
  772. if (verifier_func_ == nullptr) {
  773. verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType());
  774. }
  775. if (verifier_func_ != nullptr) {
  776. Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this());
  777. graphStatus ret = (graphStatus)verifier_func_(op_proxy);
  778. op_proxy.BreakConnect();
  779. return ret;
  780. }
  781. return GRAPH_SUCCESS;
  782. }
  783. graphStatus OpDesc::CommonVerify() const {
  784. for (const string &iname : GetAllInputNames()) {
  785. // Checking shape of all inputs
  786. vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims();
  787. for (int64_t dim : ishape) {
  788. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  789. dim < -2, ErrorManager::GetInstance().ATCReportErrMessage(
  790. "E19014", {"opname", "value", "reason"},
  791. {GetName(), "input " + iname + " shape", "contains negative or zero dimension"});
  792. return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(),
  793. iname.c_str());
  794. }
  795. }
  796. // Check all attributes defined
  797. const auto &all_attributes = GetAllAttrs();
  798. for (const auto &name : GetAllAttrNames()) {
  799. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  800. all_attributes.find(name) == all_attributes.end(),
  801. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  802. {GetName(), "attribute " + name, "is empty"});
  803. return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str());
  804. }
  805. return GRAPH_SUCCESS;
  806. }
  807. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const {
  808. auto it = input_name_idx_.begin();
  809. for (; it != input_name_idx_.end(); ++it) {
  810. if (it->second == index) {
  811. break;
  812. }
  813. }
  814. GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), "");
  815. return it->first;
  816. }
  817. int OpDesc::GetInputIndexByName(const string &name) const {
  818. auto it_find = input_name_idx_.find(name);
  819. GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1);
  820. return static_cast<int>(it_find->second);
  821. }
  822. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const {
  823. auto it = output_name_idx_.begin();
  824. for (; it != output_name_idx_.end(); ++it) {
  825. if (it->second == index) {
  826. break;
  827. }
  828. }
  829. GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), "");
  830. return it->first;
  831. }
  832. int OpDesc::GetOutputIndexByName(const string &name) const {
  833. auto it_find = output_name_idx_.find(name);
  834. GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1);
  835. return static_cast<int>(it_find->second);
  836. }
  837. ProtoAttrMapHelper OpDesc::MutableAttrMap() {
  838. if (op_def_.GetProtoMsg() == nullptr) {
  839. GELOGE(GRAPH_FAILED, "op def get proto msg failed");
  840. return GeIrProtoHelper<ProtoAttrMap>();
  841. }
  842. return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr());
  843. }
  844. ConstProtoAttrMapHelper OpDesc::GetAttrMap() const {
  845. return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr());
  846. }
  847. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) {
  848. auto proto_msg = op_def_.GetProtoMsg();
  849. if (proto_msg != nullptr) {
  850. proto_msg->set_id(id);
  851. }
  852. }
  853. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const {
  854. auto proto_msg = op_def_.GetProtoMsg();
  855. if (proto_msg != nullptr) {
  856. return proto_msg->id();
  857. }
  858. return 0;
  859. }
  860. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) {
  861. auto proto_msg = op_def_.GetProtoMsg();
  862. if (proto_msg != nullptr) {
  863. proto_msg->set_stream_id(stream_id);
  864. }
  865. }
  866. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const {
  867. auto proto_msg = op_def_.GetProtoMsg();
  868. if (proto_msg != nullptr) {
  869. return proto_msg->stream_id();
  870. }
  871. return 0;
  872. }
  873. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector<string> &input_name) {
  874. auto proto_msg = op_def_.GetProtoMsg();
  875. if (proto_msg != nullptr) {
  876. proto_msg->clear_input_name();
  877. for (auto &item : input_name) {
  878. proto_msg->add_input_name(item);
  879. }
  880. }
  881. }
  882. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetInputName() const {
  883. vector<string> input_name;
  884. auto proto_msg = op_def_.GetProtoMsg();
  885. if (proto_msg != nullptr) {
  886. for (auto &item : proto_msg->input_name()) {
  887. input_name.push_back(item);
  888. }
  889. }
  890. return input_name;
  891. }
  892. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector<string> &src_name) {
  893. auto proto_msg = op_def_.GetProtoMsg();
  894. if (proto_msg != nullptr) {
  895. proto_msg->clear_src_name();
  896. for (auto &item : src_name) {
  897. proto_msg->add_src_name(item);
  898. }
  899. }
  900. }
  901. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetSrcName() const {
  902. vector<string> src_name;
  903. auto proto_msg = op_def_.GetProtoMsg();
  904. if (proto_msg != nullptr) {
  905. for (auto &item : proto_msg->src_name()) {
  906. src_name.push_back(item);
  907. }
  908. }
  909. return src_name;
  910. }
  911. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector<int64_t> &src_index) {
  912. auto proto_msg = op_def_.GetProtoMsg();
  913. if (proto_msg != nullptr) {
  914. proto_msg->clear_src_index();
  915. for (auto &item : src_index) {
  916. proto_msg->add_src_index(item);
  917. }
  918. }
  919. }
  920. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetSrcIndex() const {
  921. vector<int64_t> src_index;
  922. auto proto_msg = op_def_.GetProtoMsg();
  923. if (proto_msg != nullptr) {
  924. for (auto &item : proto_msg->src_index()) {
  925. src_index.push_back(item);
  926. }
  927. }
  928. return src_index;
  929. }
  930. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector<int64_t> &input) {
  931. auto proto_msg = op_def_.GetProtoMsg();
  932. if (proto_msg != nullptr) {
  933. proto_msg->clear_input_i();
  934. for (auto &item : input) {
  935. proto_msg->add_input_i(item);
  936. }
  937. }
  938. }
  939. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetInputOffset() const {
  940. vector<int64_t> input;
  941. auto proto_msg = op_def_.GetProtoMsg();
  942. if (proto_msg != nullptr) {
  943. for (auto &item : proto_msg->input_i()) {
  944. input.push_back(item);
  945. }
  946. }
  947. return input;
  948. }
  949. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector<int64_t> &output) {
  950. auto proto_msg = op_def_.GetProtoMsg();
  951. if (proto_msg != nullptr) {
  952. proto_msg->clear_output_i();
  953. for (auto &item : output) {
  954. proto_msg->add_output_i(item);
  955. }
  956. }
  957. }
  958. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetOutputOffset() const {
  959. vector<int64_t> output;
  960. auto proto_msg = op_def_.GetProtoMsg();
  961. if (proto_msg != nullptr) {
  962. for (auto &item : proto_msg->output_i()) {
  963. output.push_back(item);
  964. }
  965. }
  966. return output;
  967. }
  968. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector<string> &dst_name) {
  969. auto proto_msg = op_def_.GetProtoMsg();
  970. if (proto_msg != nullptr) {
  971. proto_msg->clear_dst_name();
  972. for (auto &item : dst_name) {
  973. proto_msg->add_dst_name(item);
  974. }
  975. }
  976. }
  977. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetDstName() const {
  978. vector<string> dst_name;
  979. auto proto_msg = op_def_.GetProtoMsg();
  980. if (proto_msg != nullptr) {
  981. for (auto &item : proto_msg->dst_name()) {
  982. dst_name.push_back(item);
  983. }
  984. }
  985. return dst_name;
  986. }
  987. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector<string> &depend_names) {
  988. auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names);
  989. if (ret != true) {
  990. GELOGE(GRAPH_FAILED, "set op_infer_depends fail.");
  991. }
  992. }
  993. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetOpInferDepends() const {
  994. vector<string> depend_names;
  995. (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names);
  996. return depend_names;
  997. }
  998. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector<int64_t> &dst_index) {
  999. auto proto_msg = op_def_.GetProtoMsg();
  1000. if (proto_msg != nullptr) {
  1001. proto_msg->clear_dst_index();
  1002. for (auto &item : dst_index) {
  1003. proto_msg->add_dst_index(item);
  1004. }
  1005. }
  1006. }
  1007. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetDstIndex() const {
  1008. vector<int64_t> dst_index;
  1009. auto proto_msg = op_def_.GetProtoMsg();
  1010. if (proto_msg != nullptr) {
  1011. for (auto &item : proto_msg->dst_index()) {
  1012. dst_index.push_back(item);
  1013. }
  1014. }
  1015. return dst_index;
  1016. }
  1017. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector<int64_t> &workspace) {
  1018. auto proto_msg = op_def_.GetProtoMsg();
  1019. if (proto_msg != nullptr) {
  1020. proto_msg->clear_workspace();
  1021. for (auto &item : workspace) {
  1022. proto_msg->add_workspace(item);
  1023. }
  1024. }
  1025. }
  1026. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetWorkspace() const {
  1027. vector<int64_t> workspace;
  1028. auto proto_msg = op_def_.GetProtoMsg();
  1029. if (proto_msg != nullptr) {
  1030. for (auto &item : proto_msg->workspace()) {
  1031. workspace.push_back(item);
  1032. }
  1033. }
  1034. return workspace;
  1035. }
  1036. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector<int64_t> &workspace_bytes) {
  1037. auto proto_msg = op_def_.GetProtoMsg();
  1038. if (proto_msg != nullptr) {
  1039. proto_msg->clear_workspace_bytes();
  1040. for (auto &item : workspace_bytes) {
  1041. proto_msg->add_workspace_bytes(item);
  1042. }
  1043. }
  1044. }
  1045. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<int64_t> OpDesc::GetWorkspaceBytes() const {
  1046. vector<int64_t> workspace_bytes;
  1047. auto proto_msg = op_def_.GetProtoMsg();
  1048. if (proto_msg != nullptr) {
  1049. for (auto &item : proto_msg->workspace_bytes()) {
  1050. workspace_bytes.push_back(item);
  1051. }
  1052. }
  1053. return workspace_bytes;
  1054. }
  1055. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector<bool> &is_input_const) {
  1056. auto proto_msg = op_def_.GetProtoMsg();
  1057. if (proto_msg != nullptr) {
  1058. proto_msg->clear_is_input_const();
  1059. for (auto item : is_input_const) {
  1060. proto_msg->add_is_input_const(item);
  1061. }
  1062. }
  1063. // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag
  1064. auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const);
  1065. if (ret != true) {
  1066. GELOGE(GRAPH_FAILED, "set is_input_const fail.");
  1067. }
  1068. }
  1069. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputConst() const {
  1070. vector<bool> is_input_const;
  1071. auto proto_msg = op_def_.GetProtoMsg();
  1072. if (proto_msg != nullptr) {
  1073. for (auto item : proto_msg->is_input_const()) {
  1074. is_input_const.push_back(item);
  1075. }
  1076. }
  1077. return is_input_const;
  1078. }
  1079. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name,
  1080. const int &index) {
  1081. if (input_name_idx_.find(name) != input_name_idx_.end()) {
  1082. GELOGI("Restore input name index is existed. name[%s]", name.c_str());
  1083. }
  1084. (void)input_name_idx_.insert(make_pair(name, index));
  1085. return GRAPH_SUCCESS;
  1086. }
  1087. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name,
  1088. const int &index) {
  1089. if (output_name_idx_.find(name) != output_name_idx_.end()) {
  1090. GELOGI("Restore output name index is existed. name[%s]", name.c_str());
  1091. }
  1092. (void)output_name_idx_.insert(make_pair(name, index));
  1093. return GRAPH_SUCCESS;
  1094. }
  1095. graphStatus OpDesc::CallInferFunc(Operator &op) {
  1096. if (infer_func_ == nullptr) {
  1097. infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType());
  1098. if (infer_func_ == nullptr) {
  1099. GELOGW("%s does not have infer func.", GetName().c_str());
  1100. return GRAPH_PARAM_INVALID;
  1101. }
  1102. }
  1103. graphStatus graph_status = (graphStatus)infer_func_(op);
  1104. if (graph_status != GRAPH_SUCCESS) {
  1105. GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status);
  1106. return GRAPH_FAILED;
  1107. }
  1108. return GRAPH_SUCCESS;
  1109. }
  1110. graphStatus OpDesc::CallInferFormatFunc(Operator &op) {
  1111. if (infer_format_func_ == nullptr) {
  1112. infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType());
  1113. if (infer_format_func_ == nullptr) {
  1114. return DefaultInferFormat();
  1115. }
  1116. }
  1117. return (graphStatus)infer_format_func_(op);
  1118. }
  1119. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const {
  1120. if (static_cast<size_t>(index) >= subgraph_instance_names_.size()) {
  1121. return "";
  1122. }
  1123. return subgraph_instance_names_.at(index);
  1124. }
  1125. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames()
  1126. const {
  1127. return subgraph_instance_names_;
  1128. }
  1129. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) {
  1130. for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) {
  1131. if (*iter == name) {
  1132. *iter = "";
  1133. return;
  1134. }
  1135. }
  1136. }
  1137. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) {
  1138. GELOGI("Add subgraph name is %s", name.c_str());
  1139. auto iter = subgraph_names_to_index_.find(name);
  1140. if (iter != subgraph_names_to_index_.end()) {
  1141. GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second);
  1142. return GRAPH_FAILED;
  1143. }
  1144. auto size = subgraph_names_to_index_.size();
  1145. subgraph_names_to_index_[name] = size;
  1146. subgraph_instance_names_.resize(size + 1);
  1147. return GRAPH_SUCCESS;
  1148. }
  1149. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> &OpDesc::GetSubgraphNameIndexes()
  1150. const {
  1151. return subgraph_names_to_index_;
  1152. }
  1153. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index,
  1154. const std::string &name) {
  1155. GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index);
  1156. if (index >= subgraph_instance_names_.size()) {
  1157. GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size());
  1158. return GRAPH_PARAM_INVALID;
  1159. }
  1160. subgraph_instance_names_[index] = name;
  1161. return GRAPH_SUCCESS;
  1162. }
  1163. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name,
  1164. SubgraphType type) {
  1165. subgraph_ir_names_to_type_[name] = type;
  1166. }
  1167. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, SubgraphType> &OpDesc::GetSubgraphIrNames()
  1168. const {
  1169. return subgraph_ir_names_to_type_;
  1170. }
  1171. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType
  1172. OpDesc::GetSubgraphTypeByIrName(const std::string &name) const {
  1173. auto iter = subgraph_ir_names_to_type_.find(name);
  1174. if (iter == subgraph_ir_names_to_type_.end()) {
  1175. return kSubgraphTypeEnd;
  1176. }
  1177. return iter->second;
  1178. }
  1179. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  1180. OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const {
  1181. for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) {
  1182. if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index.
  1183. continue;
  1184. }
  1185. for (auto name_to_index : subgraph_names_to_index_) {
  1186. if (name_to_index.second != idx) { // find subgraph name.
  1187. continue;
  1188. }
  1189. subgraph_name = name_to_index.first;
  1190. return GRAPH_SUCCESS;
  1191. }
  1192. }
  1193. return GRAPH_PARAM_INVALID;
  1194. }
  1195. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。