You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_generator.cc 40 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "common/helper/model_helper.h"
  21. #include "common/helper/om_file_helper.h"
  22. #include "common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/util/rt_context_util.h"
  31. #include "graph/opsproto_manager.h"
  32. #include "graph/utils/graph_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  35. #include "model/ge_model.h"
  36. using std::map;
  37. using std::string;
  38. using std::vector;
  39. namespace {
  40. const char *const kAttrOpType = "op_type";
  41. const char *const kEngineNameDefault = "default";
  42. const char *const kVectorEngine = "VectorEngine";
  43. const char *const kAIcoreEngine = "AIcoreEngine";
  44. const char *const kFileNameSuffix = "online";
  45. const char *const kAicpuAllshape = "_AllShape";
  46. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  47. const int64_t kDynamicDimValue = -2;
  48. const int kDefaultDeviceId = 0;
  49. const int kDefaultJobId = 0;
  50. std::map<ge::OpEngineType, std::string> engine_type_map{
  51. {ge::ENGINE_SYS, kEngineNameDefault},
  52. {ge::ENGINE_AICORE, kAIcoreEngine},
  53. {ge::ENGINE_VECTOR, kVectorEngine}};
  54. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  55. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  56. if (tensor_desc->MutableShape().IsUnknownShape()) {
  57. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  58. return true;
  59. }
  60. }
  61. return false;
  62. }
  63. } // namespace
  64. namespace ge {
  65. static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) {
  66. const OpDescPtr &op_desc = node->GetOpDesc();
  67. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  68. if (engine_type == ENGINE_SYS) {
  69. GELOGI("CheckEngineType: use default engine.");
  70. return SUCCESS;
  71. }
  72. // get op engine name
  73. string op_engine_name;
  74. auto iter = engine_type_map.find(engine_type);
  75. if (iter != engine_type_map.end()) {
  76. op_engine_name = iter->second;
  77. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  78. } else {
  79. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  80. {op_desc->GetName(), op_desc->GetType(), "engine type",
  81. "it only support default/AIcoreEngine/VectorEngine"});
  82. GELOGE(FAILED, "[Check][EngineType]value:%d not support, "
  83. "only support default/AIcoreEngine/VectorEngine now", static_cast<int>(engine_type));
  84. return FAILED;
  85. }
  86. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  87. op_desc->SetOpEngineName(op_engine_name);
  88. op_desc->SetOpKernelLibName(op_engine_name);
  89. return SUCCESS;
  90. }
  91. // set op engine name and opkernelLib. when engine support
  92. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  93. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  94. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed.");
  95. return FAILED;
  96. }
  97. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  98. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  99. if (op_infos.empty()) {
  100. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  101. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  102. GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str());
  103. return FAILED;
  104. }
  105. string kernel_name;
  106. for (const auto &it : op_infos) {
  107. if (it.engine == op_engine_name) {
  108. kernel_name = it.opKernelLib;
  109. break;
  110. }
  111. }
  112. if (kernel_name.empty()) {
  113. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  114. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  115. GELOGE(FAILED, "CheckEngineType:Can not find ops kernel, engine name: %s.", op_engine_name.c_str());
  116. return FAILED;
  117. }
  118. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  119. auto kernel_info_store = kernel_map.find(kernel_name);
  120. if (kernel_info_store != kernel_map.end()) {
  121. std::string unsupported_reason;
  122. if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) {
  123. op_desc->SetOpEngineName(op_engine_name);
  124. op_desc->SetOpKernelLibName(kernel_name);
  125. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  126. op_engine_name.c_str(), op_desc->GetName().c_str());
  127. return SUCCESS;
  128. } else {
  129. ErrorManager::GetInstance().ATCReportErrMessage(
  130. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  131. GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  132. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  133. return FAILED;
  134. }
  135. } else {
  136. ErrorManager::GetInstance().ATCReportErrMessage(
  137. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  138. GELOGE(FAILED,
  139. "CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s,"
  140. "op type is %s, op name is %s",
  141. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  142. }
  143. return FAILED;
  144. }
  145. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
  146. bool attr, int32_t &data_index) {
  147. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  148. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  149. auto format = tensor.GetFormat();
  150. auto data_type = tensor.GetDataType();
  151. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  152. return SUCCESS;
  153. }
  154. string op_type;
  155. bool is_const = false;
  156. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  157. if (is_const) {
  158. GELOGD("Get input[%d] is const", index);
  159. op_type = CONSTANTOP;
  160. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  161. op_type = DATA;
  162. }
  163. string op_name = node->GetName() + "_in_" + std::to_string(index);
  164. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  165. if (data_op == nullptr) {
  166. return FAILED;
  167. }
  168. if (is_const) {
  169. ConstGeTensorPtr tensor_value;
  170. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  171. GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str());
  172. return FAILED;
  173. }
  174. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  175. GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail.");
  176. return FAILED;
  177. }
  178. }
  179. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  180. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  181. "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str());
  182. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  183. "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str());
  184. if (attr && !is_const) {
  185. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED,
  186. "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  187. ++data_index;
  188. }
  189. ge::NodePtr arg_node = graph->AddNode(data_op);
  190. GE_CHK_BOOL_EXEC(arg_node != nullptr, return FAILED, "Insert Data node fail");
  191. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  192. "[Add][Edge]fail from node:%s to node:%s", data_op->GetName().c_str(), node->GetName().c_str());
  193. return SUCCESS;
  194. }
  195. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  196. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  197. if (op_desc == nullptr) {
  198. return FAILED;
  199. }
  200. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  201. int32_t count = 0;
  202. for (const auto &out_desc : outputs) {
  203. GeTensorDesc tensor = out_desc.GetTensorDesc();
  204. TensorUtils::SetInputTensor(tensor, true);
  205. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  206. "[Add][InputDesc]fail for node:%s", op_desc->GetName().c_str());
  207. TensorUtils::SetInputTensor(tensor, false);
  208. TensorUtils::SetOutputTensor(tensor, true);
  209. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  210. "[Add][OutputDesc]fail for node:%s", op_desc->GetName().c_str());
  211. count++;
  212. }
  213. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  214. ge::NodePtr out_node = graph->AddNode(op_desc);
  215. GE_CHK_BOOL_EXEC(out_node != nullptr, return FAILED,
  216. "[Add][Node:%s]fail in graph:%u", op_desc->GetName().c_str(), graph->GetGraphID());
  217. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  218. for (int32_t i = 0; i < count; ++i) {
  219. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  220. "[Add][Edge]fail from node:%s to node:%s", node->GetName().c_str(), out_node->GetName().c_str());
  221. }
  222. return SUCCESS;
  223. }
  224. static void GetOpsProtoPath(string &opsproto_path) {
  225. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  226. if (path_env != nullptr) {
  227. string path = path_env;
  228. string file_path = RealPath(path.c_str());
  229. if (file_path.empty()) {
  230. GELOGE(FAILED, "File path %s is invalid.", path.c_str());
  231. return;
  232. }
  233. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  234. GELOGI("Get opsproto so path from env : %s", path.c_str());
  235. return;
  236. }
  237. string path_base = PluginManager::GetPath();
  238. GELOGI("path_base is %s", path_base.c_str());
  239. path_base = path_base.substr(0, path_base.rfind('/'));
  240. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  241. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  242. }
  243. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  244. for (auto input : inputs) {
  245. auto input_desc = input.GetTensorDesc();
  246. GeShape shape_ori = input_desc.GetShape();
  247. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  248. GeShape dynamic_shape(dynamic_shape_dims);
  249. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  250. ge::GeTensor inputTensor;
  251. ge::GeTensorDesc desc(input_desc);
  252. bool is_const = false;
  253. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  254. if (!is_const) {
  255. int64_t storage_format = FORMAT_NCHW;
  256. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  257. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  258. GELOGE(FAILED, "Set attr ATTR_NAME_STORAGE_SHAPE fail.");
  259. return FAILED;
  260. }
  261. desc.SetShape(dynamic_shape);
  262. desc.SetShapeRange(dynamic_shape_range);
  263. }
  264. inputTensor.SetTensorDesc(desc);
  265. inputs_dynamic.push_back(inputTensor);
  266. }
  267. return SUCCESS;
  268. }
  269. class GeGenerator::Impl {
  270. public:
  271. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  272. ~Impl() = default;
  273. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  274. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  275. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  276. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  277. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  278. Status GenerateInfershapeGraph(const Graph &graph);
  279. OmgContext &omg_context_;
  280. GraphManager graph_manager_;
  281. SaveParam save_param_;
  282. bool is_offline_ = true;
  283. bool is_singleop_unregistered_ = false;
  284. std::string build_mode_;
  285. std::string build_step_;
  286. static std::mutex mutex_;
  287. private:
  288. static std::string Trim(const std::string &str);
  289. bool ParseVersion(const std::string &line, std::string &version);
  290. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  291. bool SetAtcVersionInfo(AttrHolder &obj);
  292. bool SetOppVersionInfo(AttrHolder &obj);
  293. bool SetOmSystemInfo(AttrHolder &obj);
  294. };
  295. Status GeGenerator::Initialize(const map<string, string> &options) {
  296. return Initialize(options, domi::GetContext());
  297. }
  298. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  299. impl_ = ge::MakeShared<Impl>(omg_context);
  300. if (impl_ == nullptr) {
  301. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  302. return MEMALLOC_FAILED;
  303. }
  304. ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOpsProtoInit);
  305. string opsproto_path;
  306. GetOpsProtoPath(opsproto_path);
  307. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  308. OpsProtoManager *manager = OpsProtoManager::Instance();
  309. map<string, string> option_tmp;
  310. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  311. (void)manager->Initialize(option_tmp);
  312. Status ret = impl_->graph_manager_.Initialize(options);
  313. if (ret != SUCCESS) {
  314. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed.");
  315. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  316. }
  317. // get ek file
  318. auto iter = options.find(EK_FILE);
  319. if (iter != options.end()) {
  320. impl_->save_param_.ek_file = iter->second;
  321. }
  322. // get cert file
  323. iter = options.find(CERT_FILE);
  324. if (iter != options.end()) {
  325. impl_->save_param_.cert_file = iter->second;
  326. }
  327. // get hw key file
  328. iter = options.find(HW_KEY_FILE);
  329. if (iter != options.end()) {
  330. impl_->save_param_.hw_key_file = iter->second;
  331. }
  332. // get private file
  333. iter = options.find(PRIVATE_KEY_FILE);
  334. if (iter != options.end()) {
  335. impl_->save_param_.pri_key_file = iter->second;
  336. }
  337. // get build mode
  338. iter = options.find(BUILD_MODE);
  339. if (iter != options.end()) {
  340. impl_->build_mode_ = iter->second;
  341. }
  342. // get build step
  343. iter = options.find(BUILD_STEP);
  344. if (iter != options.end()) {
  345. impl_->build_step_ = iter->second;
  346. }
  347. return SUCCESS;
  348. }
  349. Status GeGenerator::Finalize() {
  350. ErrorManager::GetInstance().SetStage(ErrorMessage::kFinalize, ErrorMessage::kFinalize);
  351. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  352. Status ret = impl_->graph_manager_.Finalize();
  353. if (ret != SUCCESS) {
  354. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed.");
  355. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  356. }
  357. return SUCCESS;
  358. }
  359. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  360. const vector<GeTensor> &inputs) {
  361. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  362. GELOGI("Start to generate offline model.");
  363. ModelBufferData model;
  364. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  365. }
  366. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  367. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  368. return GenerateModel(graph, "online", inputs, model, false);
  369. }
  370. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  371. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  372. Status ret = impl_->GenerateInfershapeGraph(graph);
  373. if (ret != SUCCESS) {
  374. GELOGE(ret, "Dump infershape json failed");
  375. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  376. GELOGE(FAILED, "graph_manager finalize fail.");
  377. }
  378. return ret;
  379. }
  380. GELOGI("Generate infer shape graph success");
  381. return SUCCESS;
  382. }
  383. std::mutex GeGenerator::Impl::mutex_;
  384. // Remove the space and tab before and after the string
  385. std::string GeGenerator::Impl::Trim(const std::string &str) {
  386. if (str.empty()) {
  387. return str;
  388. }
  389. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  390. if (start == std::string::npos) {
  391. return str;
  392. }
  393. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  394. return str.substr(start, end);
  395. }
  396. // Parsing the command line
  397. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  398. std::string flag = "Version=";
  399. std::string temp = Trim(line);
  400. if (temp.empty()) {
  401. GELOGW("line is empty.");
  402. return false;
  403. }
  404. std::string::size_type pos = temp.find(flag);
  405. if (pos == std::string::npos) {
  406. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  407. return false;
  408. }
  409. if (temp.size() == flag.size()) {
  410. GELOGW("version information is empty. %s", line.c_str());
  411. return false;
  412. }
  413. version = temp.substr(pos + flag.size());
  414. return true;
  415. }
  416. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  417. // Normalize the path
  418. string resolved_file_path = RealPath(file_path.c_str());
  419. if (resolved_file_path.empty()) {
  420. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  421. return false;
  422. }
  423. std::ifstream fs(resolved_file_path, std::ifstream::in);
  424. if (!fs.is_open()) {
  425. GELOGW("Open %s failed.", file_path.c_str());
  426. return false;
  427. }
  428. std::string line;
  429. if (getline(fs, line)) {
  430. if (!ParseVersion(line, version)) {
  431. GELOGW("Parse version failed. content is [%s].", line.c_str());
  432. fs.close();
  433. return false;
  434. }
  435. } else {
  436. GELOGW("No version information found in the file path:%s", file_path.c_str());
  437. fs.close();
  438. return false;
  439. }
  440. fs.close(); // close the file
  441. return true;
  442. }
  443. // Set package version information in the model
  444. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  445. std::string path_base = ge::GELib::GetPath();
  446. path_base = path_base.substr(0, path_base.rfind('/'));
  447. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  448. std::string version_path = path_base + "version.info";
  449. std::string version;
  450. if (!GetVersionFromPath(version_path, version)) {
  451. GELOGW("Get atc version information failed!");
  452. return false;
  453. }
  454. // set version info
  455. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  456. GELOGW("Ge model set atc version failed!");
  457. return false;
  458. }
  459. return true;
  460. }
  461. // Set package version information in the model
  462. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  463. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  464. if (path_env == nullptr) {
  465. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  466. return false;
  467. }
  468. std::string version_path = path_env;
  469. version_path += "/version.info";
  470. std::string version;
  471. if (!GetVersionFromPath(version_path, version)) {
  472. GELOGW("Get opp version information failed!");
  473. return false;
  474. }
  475. // set version info
  476. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  477. GELOGW("Ge model set opp version failed!");
  478. return false;
  479. }
  480. return true;
  481. }
  482. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  483. std::string soc_version;
  484. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  485. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  486. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  487. GELOGW("SetStr of soc_version failed.");
  488. return false;
  489. }
  490. std::string framework_type;
  491. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  492. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  493. auto iter = ge::kFwkTypeToStr.find(framework_type);
  494. if (iter == ge::kFwkTypeToStr.end()) {
  495. GELOGW("Can not find framework_type in the map.");
  496. return false;
  497. }
  498. if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
  499. GELOGW("SetStr of framework_type failed.");
  500. return false;
  501. }
  502. return true;
  503. }
  504. Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
  505. bool is_unknown_shape = false;
  506. Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  507. if (ret != SUCCESS) {
  508. GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
  509. ge_root_model->GetModelId());
  510. REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%u",
  511. ge_root_model->GetModelId());
  512. return FAILED;
  513. }
  514. GeModelPtr model_root = nullptr;
  515. if (is_unknown_shape) {
  516. model_root = MakeShared<GeModel>();
  517. GE_CHECK_NOTNULL(model_root);
  518. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  519. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  520. }
  521. ModelHelper model_helper;
  522. string model_name;
  523. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  524. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  525. model_name);
  526. if (name_ret != SUCCESS) {
  527. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  528. GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
  529. ge_root_model->GetRootGraph()->GetName().c_str());
  530. return PARAM_INVALID;
  531. }
  532. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  533. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  534. GE_CHECK_NOTNULL(ge_model);
  535. ge_model->SetName(model_name);
  536. return SUCCESS;
  537. }
  538. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  539. ModelBufferData &model, bool is_offline) {
  540. rtContext_t ctx = nullptr;
  541. auto rt = rtCtxGetCurrent(&ctx);
  542. if (rt != RT_ERROR_NONE) {
  543. GELOGD("Current ctx is null.");
  544. ctx = nullptr;
  545. }
  546. GeRootModelPtr ge_root_model = nullptr;
  547. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  548. impl_->is_offline_ = is_offline;
  549. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  550. if (ret != SUCCESS) {
  551. GELOGE(ret, "Build model failed.");
  552. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  553. GELOGE(FAILED, "graph_manager finalize fail.");
  554. }
  555. return ret;
  556. }
  557. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  558. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  559. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  560. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  561. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  562. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  563. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  564. impl_->build_mode_.c_str(),
  565. impl_->build_step_.c_str());
  566. return SUCCESS;
  567. }
  568. GE_CHECK_NOTNULL(ge_root_model);
  569. ret = SetModelNameForDump(ge_root_model);
  570. if (ret != SUCCESS) {
  571. return ret;
  572. }
  573. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  574. if (ret != SUCCESS) {
  575. GELOGE(ret, "Save model failed");
  576. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  577. GELOGE(FAILED, "graph_manager finalize fail.");
  578. }
  579. return ret;
  580. }
  581. if (ctx != nullptr) {
  582. (void)rtCtxSetCurrent(ctx);
  583. }
  584. return SUCCESS;
  585. }
  586. namespace {
  587. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  588. bool is_need = true;
  589. // format and dtype is all reserved, stand for Optional input. When singleop scene
  590. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  591. is_need = false;
  592. }
  593. return is_need;
  594. }
  595. Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) {
  596. bool support_dynamic = true;
  597. bool is_dynamic = false;
  598. for (const auto &node : graph->GetDirectNode()) {
  599. GE_CHECK_NOTNULL(node);
  600. auto op_desc = node->GetOpDesc();
  601. GE_CHECK_NOTNULL(op_desc);
  602. if (op_desc->GetOpEngineName() != kAIcoreEngine) {
  603. continue;
  604. }
  605. if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) {
  606. is_dynamic = true;
  607. (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic);
  608. if (!support_dynamic) {
  609. GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str());
  610. break;
  611. }
  612. }
  613. }
  614. if (is_dynamic) {
  615. (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic);
  616. }
  617. return SUCCESS;
  618. }
  619. }
  620. bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) {
  621. for (const auto &node : graph->GetDirectNode()) {
  622. if (node == nullptr) {
  623. continue;
  624. }
  625. auto op_desc = node->GetOpDesc();
  626. if (op_desc == nullptr) {
  627. continue;
  628. }
  629. if (op_desc->GetOpEngineName() == kAIcoreEngine) {
  630. return false;
  631. }
  632. }
  633. return true;
  634. }
  635. void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  636. for (auto &input : inputs) {
  637. GeTensorDesc input_desc = input.GetTensorDesc();
  638. bool is_const = false;
  639. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  640. if (!is_const) {
  641. outputs.emplace_back(input);
  642. }
  643. }
  644. }
  645. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  646. const vector<GeTensor> &outputs) {
  647. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  648. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  649. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  650. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  651. "tensor size is " + FmtToStr(inputs.size())});
  652. GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize());
  653. return PARAM_INVALID;
  654. }
  655. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  656. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  657. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  658. "tensor size is " + FmtToStr(outputs.size())});
  659. GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize());
  660. return PARAM_INVALID;
  661. }
  662. return SUCCESS;
  663. }
  664. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  665. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  666. bool is_offline) {
  667. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  668. impl_->is_offline_ = is_offline;
  669. if (!is_offline) {
  670. (void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
  671. }
  672. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  673. GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
  674. return PARAM_INVALID;
  675. }
  676. OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
  677. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  678. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  679. impl_->is_singleop_unregistered_ = true;
  680. }
  681. // 0. Save original attributes.
  682. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  683. GE_CHECK_NOTNULL(op_desc_tmp);
  684. // 1. Create ComputeGraph.
  685. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  686. Graph graph;
  687. GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "make graph fail.");
  688. // 2. check engine type when compile online
  689. if (model_file_name == kFileNameSuffix) {
  690. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  691. GE_CHECK_NOTNULL(comp_graph);
  692. auto node = comp_graph->FindNode(op_desc->GetName());
  693. Status ret = CheckEngineTypeSupport(node, engine_type);
  694. if (ret != SUCCESS) {
  695. GELOGE(ret, "[Check][EngineType]value:%d for node:%s not support", engine_type, node->GetName().c_str());
  696. return ret;
  697. }
  698. }
  699. GELOGI("ATC parser success in single op build.");
  700. GeRootModelPtr ge_root_model = nullptr;
  701. vector<GeTensor> data_inputs;
  702. RemoveConst(inputs, data_inputs);
  703. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
  704. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  705. GE_CHECK_NOTNULL(ge_root_model);
  706. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  707. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  708. if (name_to_ge_model.empty()) {
  709. GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty.");
  710. return PARAM_INVALID;
  711. }
  712. const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph();
  713. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  714. GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph));
  715. GELOGI("After build model, The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  716. bool all_shape = false;
  717. (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
  718. if (all_shape && CheckNoAicore(root_graph)) {
  719. GELOGD("Get aicpu all_shape kernel!");
  720. vector<GeTensor> inputs_dynamic;
  721. vector<GeTensor> outputs_dynamic;
  722. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  723. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  724. GE_CHK_STATUS_RET_NOLOG(
  725. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  726. } else {
  727. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  728. }
  729. GELOGI("Start save GeModel to Model buffer");
  730. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  731. return SUCCESS;
  732. }
  733. /**
  734. * @ingroup ge
  735. * @brief Compiling a single operator into an offline model
  736. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  737. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  738. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  739. * @param [in] const string &model_file_name: Offline model filename.
  740. * @return SUCCESS handle successfully / others handle failed
  741. */
  742. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  743. const vector<GeTensor> &outputs, const string &model_file_name) {
  744. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  745. GELOGI("Start to build single op offline model, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  746. ModelBufferData model_buff;
  747. OpEngineType engine_type = ENGINE_SYS;
  748. Status status = BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true);
  749. GELOGI("Finish build single offline model, status: %u", status);
  750. return status;
  751. }
  752. /**
  753. * @ingroup ge
  754. * @brief Compiling a single operator into online buffer
  755. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  756. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  757. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  758. * @param [in] engine_type: specific engine.
  759. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  760. * @return SUCCESS handle successfully / others handle failed
  761. */
  762. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  763. const vector<GeTensor> &outputs, OpEngineType engine_type,
  764. ModelBufferData &model_buff) {
  765. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  766. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  767. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  768. GELOGI("Finish build single online model, status: %u", status);
  769. return status;
  770. }
  771. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  772. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  773. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  774. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  775. // 1. Add Node to ComputeGraph.
  776. NodePtr op_node = compute_graph->AddNode(op_desc);
  777. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  778. // 2. Create InputData node.
  779. int32_t arg_index = 0;
  780. int32_t data_index = 0;
  781. if (inputs.empty()) {
  782. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  783. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  784. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  785. continue;
  786. }
  787. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
  788. arg_index++;
  789. }
  790. } else {
  791. for (const auto &in_desc : inputs) {
  792. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
  793. arg_index++;
  794. }
  795. }
  796. // 3. Create Output node.
  797. if (!outputs.empty()) {
  798. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  799. }
  800. // dump ComputeGraph node.
  801. compute_graph->Dump();
  802. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  803. return SUCCESS;
  804. }
  805. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  806. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  807. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  808. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  809. (void)graph_manager_.Finalize();
  810. return FAILED);
  811. return SUCCESS;
  812. }
  813. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  814. // set atc version
  815. if (!SetAtcVersionInfo(*(model.get()))) {
  816. GELOGW("SetPackageVersionInfo of atc failed!");
  817. }
  818. // set opp version
  819. if (!SetOppVersionInfo(*(model.get()))) {
  820. GELOGW("SetPackageVersionInfo of ops failed!");
  821. }
  822. ModelHelper model_helper;
  823. model_helper.SetSaveMode(is_offline_);
  824. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  825. if (ret != SUCCESS) {
  826. GELOGE(ret, "Save to om model failed");
  827. return ret;
  828. }
  829. return SUCCESS;
  830. }
  831. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  832. ModelBufferData &model_buff) {
  833. bool is_unknown_shape = false;
  834. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  835. if (ret != SUCCESS) {
  836. GELOGE(FAILED, "Check root model is unkonwn shape failed");
  837. return FAILED;
  838. }
  839. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  840. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED,
  841. "ge root model has no sub model")
  842. GeModelPtr model_root = nullptr;
  843. if (is_unknown_shape) {
  844. auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  845. model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  846. } else {
  847. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  848. }
  849. GE_CHECK_NOTNULL(model_root);
  850. // set atc version
  851. if (!SetAtcVersionInfo(*(model_root.get()))) {
  852. GELOGW("SetPackageVersionInfo of atc failed!");
  853. }
  854. // set opp version
  855. if (!SetOppVersionInfo(*(model_root.get()))) {
  856. GELOGW("SetPackageVersionInfo of ops failed!");
  857. }
  858. if (!SetOmSystemInfo(*(model_root.get()))) {
  859. GELOGW("SetOmsystemInfo failed!");
  860. }
  861. ModelHelper model_helper;
  862. model_helper.SetSaveMode(is_offline_);
  863. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  864. if (ret != SUCCESS) {
  865. GELOGE(ret, "Save to om model failed");
  866. return ret;
  867. }
  868. return SUCCESS;
  869. }
  870. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  871. GeRootModelPtr &ge_root_model) {
  872. static std::atomic<GraphId> atomic_graph_id(0);
  873. auto graph_id = atomic_graph_id.fetch_add(1);
  874. const std::map<std::string, std::string> options;
  875. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  876. if (ret != SUCCESS) {
  877. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", graph_id);
  878. (void)graph_manager_.Finalize();
  879. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  880. }
  881. graph_manager_.SetOptionsRunGraphFlag(false);
  882. static std::atomic<uint64_t> atomic_session_id(0);
  883. auto session_id = atomic_session_id.fetch_add(1);
  884. // This is a temporary add for graph with variable
  885. auto version = static_cast<int32_t>(SessionVersion::ClOUD_VERSION);
  886. ret = VarManager::Instance(session_id)->Init(version, session_id, kDefaultDeviceId, kDefaultJobId);
  887. GELOGI("Start init var instance, session_id %lu", session_id);
  888. if (ret != SUCCESS) {
  889. GELOGW("Failed init var instance, session_id %lu", session_id);
  890. }
  891. if (is_singleop_unregistered_) {
  892. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  893. } else {
  894. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  895. }
  896. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  897. if (ret != SUCCESS) {
  898. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", graph_id);
  899. VarManagerPool::Instance().RemoveVarManager(session_id);
  900. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  901. }
  902. VarManagerPool::Instance().RemoveVarManager(session_id);
  903. return SUCCESS;
  904. }
  905. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  906. static std::atomic<GraphId> atomic_graph_id(0);
  907. auto graph_id = atomic_graph_id.fetch_add(1);
  908. const std::map<std::string, std::string> options;
  909. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  910. if (ret != SUCCESS) {
  911. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", graph_id);
  912. (void)graph_manager_.Finalize();
  913. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  914. }
  915. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  916. if (ret != SUCCESS) {
  917. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed");
  918. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  919. }
  920. return SUCCESS;
  921. }
  922. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：