You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ge_generator.cc 41 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "common/helper/model_helper.h"
  21. #include "common/helper/om_file_helper.h"
  22. #include "common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/util/rt_context_util.h"
  31. #include "graph/opsproto_manager.h"
  32. #include "graph/utils/graph_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  35. #include "model/ge_model.h"
  36. #include "analyzer/analyzer.h"
  37. using std::map;
  38. using std::string;
  39. using std::vector;
  40. namespace {
  41. const char *const kAttrOpType = "op_type";
  42. const char *const kEngineNameDefault = "default";
  43. const char *const kVectorEngine = "VectorEngine";
  44. const char *const kAIcoreEngine = "AIcoreEngine";
  45. const char *const kFileNameSuffix = "online";
  46. const char *const kAicpuAllshape = "_AllShape";
  47. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  48. const int64_t kDynamicDimValue = -2;
  49. const int kDefaultDeviceId = 0;
  50. const int kDefaultJobId = 0;
  51. std::map<ge::OpEngineType, std::string> engine_type_map{
  52. {ge::ENGINE_SYS, kEngineNameDefault},
  53. {ge::ENGINE_AICORE, kAIcoreEngine},
  54. {ge::ENGINE_VECTOR, kVectorEngine}};
  55. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  56. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  57. if (tensor_desc->MutableShape().IsUnknownShape()) {
  58. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  59. return true;
  60. }
  61. }
  62. return false;
  63. }
  64. } // namespace
  65. namespace ge {
  66. static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) {
  67. const OpDescPtr &op_desc = node->GetOpDesc();
  68. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  69. if (engine_type == ENGINE_SYS) {
  70. GELOGI("CheckEngineType: use default engine.");
  71. return SUCCESS;
  72. }
  73. // get op engine name
  74. string op_engine_name;
  75. auto iter = engine_type_map.find(engine_type);
  76. if (iter != engine_type_map.end()) {
  77. op_engine_name = iter->second;
  78. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  79. } else {
  80. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  81. {op_desc->GetName(), op_desc->GetType(), "engine type",
  82. "it only support default/AIcoreEngine/VectorEngine"});
  83. GELOGE(FAILED, "[Check][EngineType]value:%d not support, "
  84. "only support default/AIcoreEngine/VectorEngine now", static_cast<int>(engine_type));
  85. return FAILED;
  86. }
  87. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  88. op_desc->SetOpEngineName(op_engine_name);
  89. op_desc->SetOpKernelLibName(op_engine_name);
  90. return SUCCESS;
  91. }
  92. // set op engine name and opkernelLib. when engine support
  93. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  94. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  95. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed.");
  96. return FAILED;
  97. }
  98. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  99. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  100. if (op_infos.empty()) {
  101. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  102. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  103. GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str());
  104. return FAILED;
  105. }
  106. string kernel_name;
  107. for (const auto &it : op_infos) {
  108. if (it.engine == op_engine_name) {
  109. kernel_name = it.opKernelLib;
  110. break;
  111. }
  112. }
  113. if (kernel_name.empty()) {
  114. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  115. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  116. GELOGE(FAILED, "CheckEngineType:Can not find ops kernel, engine name: %s.", op_engine_name.c_str());
  117. return FAILED;
  118. }
  119. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  120. auto kernel_info_store = kernel_map.find(kernel_name);
  121. if (kernel_info_store != kernel_map.end()) {
  122. std::string unsupported_reason;
  123. if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) {
  124. op_desc->SetOpEngineName(op_engine_name);
  125. op_desc->SetOpKernelLibName(kernel_name);
  126. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  127. op_engine_name.c_str(), op_desc->GetName().c_str());
  128. return SUCCESS;
  129. } else {
  130. ErrorManager::GetInstance().ATCReportErrMessage(
  131. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  132. GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  133. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  134. return FAILED;
  135. }
  136. } else {
  137. ErrorManager::GetInstance().ATCReportErrMessage(
  138. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  139. GELOGE(FAILED,
  140. "CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s,"
  141. "op type is %s, op name is %s",
  142. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  143. }
  144. return FAILED;
  145. }
  146. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
  147. bool attr, int32_t &data_index) {
  148. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  149. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  150. auto format = tensor.GetFormat();
  151. auto data_type = tensor.GetDataType();
  152. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  153. return SUCCESS;
  154. }
  155. string op_type;
  156. bool is_const = false;
  157. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  158. if (is_const) {
  159. GELOGD("Get input[%d] is const", index);
  160. op_type = CONSTANTOP;
  161. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  162. op_type = DATA;
  163. }
  164. string op_name = node->GetName() + "_in_" + std::to_string(index);
  165. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  166. if (data_op == nullptr) {
  167. return FAILED;
  168. }
  169. if (is_const) {
  170. ConstGeTensorPtr tensor_value;
  171. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  172. GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str());
  173. return FAILED;
  174. }
  175. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  176. GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail.");
  177. return FAILED;
  178. }
  179. }
  180. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  181. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  182. "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str());
  183. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  184. "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str());
  185. if (attr && !is_const) {
  186. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED,
  187. "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  188. ++data_index;
  189. }
  190. ge::NodePtr arg_node = graph->AddNode(data_op);
  191. GE_CHK_BOOL_EXEC(arg_node != nullptr, return FAILED, "Insert Data node fail");
  192. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  193. "[Add][Edge]fail from node:%s to node:%s", data_op->GetName().c_str(), node->GetName().c_str());
  194. return SUCCESS;
  195. }
  196. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  197. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  198. if (op_desc == nullptr) {
  199. return FAILED;
  200. }
  201. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  202. int32_t count = 0;
  203. for (const auto &out_desc : outputs) {
  204. GeTensorDesc tensor = out_desc.GetTensorDesc();
  205. TensorUtils::SetInputTensor(tensor, true);
  206. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  207. "[Add][InputDesc]fail for node:%s", op_desc->GetName().c_str());
  208. TensorUtils::SetInputTensor(tensor, false);
  209. TensorUtils::SetOutputTensor(tensor, true);
  210. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  211. "[Add][OutputDesc]fail for node:%s", op_desc->GetName().c_str());
  212. count++;
  213. }
  214. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  215. ge::NodePtr out_node = graph->AddNode(op_desc);
  216. GE_CHK_BOOL_EXEC(out_node != nullptr, return FAILED,
  217. "[Add][Node:%s]fail in graph:%u", op_desc->GetName().c_str(), graph->GetGraphID());
  218. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  219. for (int32_t i = 0; i < count; ++i) {
  220. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  221. "[Add][Edge]fail from node:%s to node:%s", node->GetName().c_str(), out_node->GetName().c_str());
  222. }
  223. return SUCCESS;
  224. }
  225. static void GetOpsProtoPath(string &opsproto_path) {
  226. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  227. if (path_env != nullptr) {
  228. string path = path_env;
  229. string file_path = RealPath(path.c_str());
  230. if (file_path.empty()) {
  231. GELOGE(FAILED, "File path %s is invalid.", path.c_str());
  232. return;
  233. }
  234. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  235. GELOGI("Get opsproto so path from env : %s", path.c_str());
  236. return;
  237. }
  238. string path_base = PluginManager::GetPath();
  239. GELOGI("path_base is %s", path_base.c_str());
  240. path_base = path_base.substr(0, path_base.rfind('/'));
  241. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  242. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  243. }
  244. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  245. for (auto input : inputs) {
  246. auto input_desc = input.GetTensorDesc();
  247. GeShape shape_ori = input_desc.GetShape();
  248. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  249. GeShape dynamic_shape(dynamic_shape_dims);
  250. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  251. ge::GeTensor inputTensor;
  252. ge::GeTensorDesc desc(input_desc);
  253. bool is_const = false;
  254. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  255. if (!is_const) {
  256. int64_t storage_format = FORMAT_NCHW;
  257. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  258. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  259. GELOGE(FAILED, "Set attr ATTR_NAME_STORAGE_SHAPE fail.");
  260. return FAILED;
  261. }
  262. desc.SetShape(dynamic_shape);
  263. desc.SetShapeRange(dynamic_shape_range);
  264. }
  265. inputTensor.SetTensorDesc(desc);
  266. inputs_dynamic.push_back(inputTensor);
  267. }
  268. return SUCCESS;
  269. }
  270. class GeGenerator::Impl {
  271. public:
  272. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  273. ~Impl() = default;
  274. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  275. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  276. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  277. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  278. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  279. Status GenerateInfershapeGraph(const Graph &graph);
  280. OmgContext &omg_context_;
  281. GraphManager graph_manager_;
  282. SaveParam save_param_;
  283. bool is_offline_ = true;
  284. bool is_singleop_unregistered_ = false;
  285. std::string build_mode_;
  286. std::string build_step_;
  287. static std::mutex mutex_;
  288. private:
  289. static std::string Trim(const std::string &str);
  290. bool ParseVersion(const std::string &line, std::string &version);
  291. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  292. bool SetAtcVersionInfo(AttrHolder &obj);
  293. bool SetOppVersionInfo(AttrHolder &obj);
  294. bool SetOmSystemInfo(AttrHolder &obj);
  295. };
  296. Status GeGenerator::Initialize(const map<string, string> &options) {
  297. return Initialize(options, domi::GetContext());
  298. }
  299. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  300. impl_ = ge::MakeShared<Impl>(omg_context);
  301. if (impl_ == nullptr) {
  302. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  303. return MEMALLOC_FAILED;
  304. }
  305. ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOpsProtoInit);
  306. string opsproto_path;
  307. GetOpsProtoPath(opsproto_path);
  308. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  309. OpsProtoManager *manager = OpsProtoManager::Instance();
  310. map<string, string> option_tmp;
  311. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  312. (void)manager->Initialize(option_tmp);
  313. Status ret = impl_->graph_manager_.Initialize(options);
  314. if (ret != SUCCESS) {
  315. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed.");
  316. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  317. }
  318. // get ek file
  319. auto iter = options.find(EK_FILE);
  320. if (iter != options.end()) {
  321. impl_->save_param_.ek_file = iter->second;
  322. }
  323. // get cert file
  324. iter = options.find(CERT_FILE);
  325. if (iter != options.end()) {
  326. impl_->save_param_.cert_file = iter->second;
  327. }
  328. // get hw key file
  329. iter = options.find(HW_KEY_FILE);
  330. if (iter != options.end()) {
  331. impl_->save_param_.hw_key_file = iter->second;
  332. }
  333. // get private file
  334. iter = options.find(PRIVATE_KEY_FILE);
  335. if (iter != options.end()) {
  336. impl_->save_param_.pri_key_file = iter->second;
  337. }
  338. // get build mode
  339. iter = options.find(BUILD_MODE);
  340. if (iter != options.end()) {
  341. impl_->build_mode_ = iter->second;
  342. }
  343. // get build step
  344. iter = options.find(BUILD_STEP);
  345. if (iter != options.end()) {
  346. impl_->build_step_ = iter->second;
  347. }
  348. return SUCCESS;
  349. }
  350. Status GeGenerator::Finalize() {
  351. ErrorManager::GetInstance().SetStage(ErrorMessage::kFinalize, ErrorMessage::kFinalize);
  352. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  353. Status ret = impl_->graph_manager_.Finalize();
  354. if (ret != SUCCESS) {
  355. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed.");
  356. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  357. }
  358. return SUCCESS;
  359. }
  360. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  361. const vector<GeTensor> &inputs) {
  362. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  363. GELOGI("Start to generate offline model.");
  364. ModelBufferData model;
  365. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  366. }
  367. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  368. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  369. return GenerateModel(graph, "online", inputs, model, false);
  370. }
  371. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  372. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  373. Status ret = impl_->GenerateInfershapeGraph(graph);
  374. if (ret != SUCCESS) {
  375. GELOGE(ret, "Dump infershape json failed");
  376. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  377. GELOGE(FAILED, "graph_manager finalize fail.");
  378. }
  379. return ret;
  380. }
  381. GELOGI("Generate infer shape graph success");
  382. return SUCCESS;
  383. }
  384. std::mutex GeGenerator::Impl::mutex_;
  385. // Remove the space and tab before and after the string
  386. std::string GeGenerator::Impl::Trim(const std::string &str) {
  387. if (str.empty()) {
  388. return str;
  389. }
  390. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  391. if (start == std::string::npos) {
  392. return str;
  393. }
  394. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  395. return str.substr(start, end);
  396. }
  397. // Parsing the command line
  398. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  399. std::string flag = "Version=";
  400. std::string temp = Trim(line);
  401. if (temp.empty()) {
  402. GELOGW("line is empty.");
  403. return false;
  404. }
  405. std::string::size_type pos = temp.find(flag);
  406. if (pos == std::string::npos) {
  407. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  408. return false;
  409. }
  410. if (temp.size() == flag.size()) {
  411. GELOGW("version information is empty. %s", line.c_str());
  412. return false;
  413. }
  414. version = temp.substr(pos + flag.size());
  415. return true;
  416. }
  417. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  418. // Normalize the path
  419. string resolved_file_path = RealPath(file_path.c_str());
  420. if (resolved_file_path.empty()) {
  421. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  422. return false;
  423. }
  424. std::ifstream fs(resolved_file_path, std::ifstream::in);
  425. if (!fs.is_open()) {
  426. GELOGW("Open %s failed.", file_path.c_str());
  427. return false;
  428. }
  429. std::string line;
  430. if (getline(fs, line)) {
  431. if (!ParseVersion(line, version)) {
  432. GELOGW("Parse version failed. content is [%s].", line.c_str());
  433. fs.close();
  434. return false;
  435. }
  436. } else {
  437. GELOGW("No version information found in the file path:%s", file_path.c_str());
  438. fs.close();
  439. return false;
  440. }
  441. fs.close(); // close the file
  442. return true;
  443. }
  444. // Set package version information in the model
  445. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  446. std::string path_base = ge::GELib::GetPath();
  447. path_base = path_base.substr(0, path_base.rfind('/'));
  448. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  449. std::string version_path = path_base + "version.info";
  450. std::string version;
  451. if (!GetVersionFromPath(version_path, version)) {
  452. GELOGW("Get atc version information failed!");
  453. return false;
  454. }
  455. // set version info
  456. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  457. GELOGW("Ge model set atc version failed!");
  458. return false;
  459. }
  460. return true;
  461. }
  462. // Set package version information in the model
  463. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  464. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  465. if (path_env == nullptr) {
  466. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  467. return false;
  468. }
  469. std::string version_path = path_env;
  470. version_path += "/version.info";
  471. std::string version;
  472. if (!GetVersionFromPath(version_path, version)) {
  473. GELOGW("Get opp version information failed!");
  474. return false;
  475. }
  476. // set version info
  477. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  478. GELOGW("Ge model set opp version failed!");
  479. return false;
  480. }
  481. return true;
  482. }
  483. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  484. std::string soc_version;
  485. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  486. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  487. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  488. GELOGW("SetStr of soc_version failed.");
  489. return false;
  490. }
  491. std::string framework_type;
  492. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  493. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  494. auto iter = ge::kFwkTypeToStr.find(framework_type);
  495. if (iter == ge::kFwkTypeToStr.end()) {
  496. GELOGW("Can not find framework_type in the map.");
  497. return false;
  498. }
  499. if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
  500. GELOGW("SetStr of framework_type failed.");
  501. return false;
  502. }
  503. return true;
  504. }
  505. Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
  506. bool is_unknown_shape = false;
  507. Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  508. if (ret != SUCCESS) {
  509. GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
  510. ge_root_model->GetModelId());
  511. REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%u",
  512. ge_root_model->GetModelId());
  513. return FAILED;
  514. }
  515. GeModelPtr model_root = nullptr;
  516. if (is_unknown_shape) {
  517. model_root = MakeShared<GeModel>();
  518. GE_CHECK_NOTNULL(model_root);
  519. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  520. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  521. }
  522. ModelHelper model_helper;
  523. string model_name;
  524. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  525. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  526. model_name);
  527. if (name_ret != SUCCESS) {
  528. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  529. GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
  530. ge_root_model->GetRootGraph()->GetName().c_str());
  531. return PARAM_INVALID;
  532. }
  533. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  534. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  535. GE_CHECK_NOTNULL(ge_model);
  536. ge_model->SetName(model_name);
  537. return SUCCESS;
  538. }
  539. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  540. ModelBufferData &model, bool is_offline) {
  541. rtContext_t ctx = nullptr;
  542. auto rt = rtCtxGetCurrent(&ctx);
  543. if (rt != RT_ERROR_NONE) {
  544. GELOGD("Current ctx is null.");
  545. ctx = nullptr;
  546. }
  547. GeRootModelPtr ge_root_model = nullptr;
  548. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  549. impl_->is_offline_ = is_offline;
  550. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  551. if (ret != SUCCESS) {
  552. GELOGE(ret, "Build model failed.");
  553. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  554. GELOGE(FAILED, "graph_manager finalize fail.");
  555. }
  556. return ret;
  557. }
  558. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  559. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  560. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  561. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  562. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  563. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  564. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  565. impl_->build_mode_.c_str(),
  566. impl_->build_step_.c_str());
  567. return SUCCESS;
  568. }
  569. GE_CHECK_NOTNULL(ge_root_model);
  570. ret = SetModelNameForDump(ge_root_model);
  571. if (ret != SUCCESS) {
  572. return ret;
  573. }
  574. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  575. if (ret != SUCCESS) {
  576. GELOGE(ret, "Save model failed");
  577. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  578. GELOGE(FAILED, "graph_manager finalize fail.");
  579. }
  580. return ret;
  581. }
  582. if (ctx != nullptr) {
  583. (void)rtCtxSetCurrent(ctx);
  584. }
  585. return SUCCESS;
  586. }
  587. namespace {
  588. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  589. bool is_need = true;
  590. // format and dtype is all reserved, stand for Optional input. When singleop scene
  591. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  592. is_need = false;
  593. }
  594. return is_need;
  595. }
  596. Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) {
  597. bool support_dynamic = true;
  598. bool is_dynamic = false;
  599. for (const auto &node : graph->GetDirectNode()) {
  600. GE_CHECK_NOTNULL(node);
  601. auto op_desc = node->GetOpDesc();
  602. GE_CHECK_NOTNULL(op_desc);
  603. if (op_desc->GetOpEngineName() != kAIcoreEngine) {
  604. continue;
  605. }
  606. if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) {
  607. is_dynamic = true;
  608. (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic);
  609. if (!support_dynamic) {
  610. GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str());
  611. break;
  612. }
  613. }
  614. }
  615. if (is_dynamic) {
  616. (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic);
  617. }
  618. return SUCCESS;
  619. }
  620. }
  621. bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) {
  622. for (const auto &node : graph->GetDirectNode()) {
  623. if (node == nullptr) {
  624. continue;
  625. }
  626. auto op_desc = node->GetOpDesc();
  627. if (op_desc == nullptr) {
  628. continue;
  629. }
  630. if (op_desc->GetOpEngineName() == kAIcoreEngine) {
  631. return false;
  632. }
  633. }
  634. return true;
  635. }
  636. void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  637. for (auto &input : inputs) {
  638. GeTensorDesc input_desc = input.GetTensorDesc();
  639. bool is_const = false;
  640. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  641. if (!is_const) {
  642. outputs.emplace_back(input);
  643. }
  644. }
  645. }
  646. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  647. const vector<GeTensor> &outputs) {
  648. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  649. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  650. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  651. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  652. "tensor size is " + FmtToStr(inputs.size())});
  653. GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize());
  654. return PARAM_INVALID;
  655. }
  656. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  657. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  658. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  659. "tensor size is " + FmtToStr(outputs.size())});
  660. GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize());
  661. return PARAM_INVALID;
  662. }
  663. return SUCCESS;
  664. }
  665. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  666. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  667. bool is_offline) {
  668. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  669. impl_->is_offline_ = is_offline;
  670. if (!is_offline) {
  671. (void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
  672. }
  673. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  674. GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
  675. return PARAM_INVALID;
  676. }
  677. OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
  678. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  679. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  680. impl_->is_singleop_unregistered_ = true;
  681. }
  682. // 0. Save original attributes.
  683. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  684. GE_CHECK_NOTNULL(op_desc_tmp);
  685. // 1. Create ComputeGraph.
  686. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  687. Graph graph;
  688. GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "make graph fail.");
  689. // 2. check engine type when compile online
  690. if (model_file_name == kFileNameSuffix) {
  691. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  692. GE_CHECK_NOTNULL(comp_graph);
  693. auto node = comp_graph->FindNode(op_desc->GetName());
  694. Status ret = CheckEngineTypeSupport(node, engine_type);
  695. if (ret != SUCCESS) {
  696. GELOGE(ret, "[Check][EngineType]value:%d for node:%s not support", engine_type, node->GetName().c_str());
  697. return ret;
  698. }
  699. }
  700. GELOGI("ATC parser success in single op build.");
  701. GeRootModelPtr ge_root_model = nullptr;
  702. vector<GeTensor> data_inputs;
  703. RemoveConst(inputs, data_inputs);
  704. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
  705. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  706. GE_CHECK_NOTNULL(ge_root_model);
  707. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  708. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  709. if (name_to_ge_model.empty()) {
  710. GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty.");
  711. return PARAM_INVALID;
  712. }
  713. const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph();
  714. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  715. GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph));
  716. GELOGI("After build model, The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  717. bool all_shape = false;
  718. (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
  719. if (all_shape && CheckNoAicore(root_graph)) {
  720. GELOGD("Get aicpu all_shape kernel!");
  721. vector<GeTensor> inputs_dynamic;
  722. vector<GeTensor> outputs_dynamic;
  723. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  724. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  725. GE_CHK_STATUS_RET_NOLOG(
  726. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  727. } else {
  728. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  729. }
  730. GELOGI("Start save GeModel to Model buffer");
  731. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  732. return SUCCESS;
  733. }
  734. /**
  735. * @ingroup ge
  736. * @brief Compiling a single operator into an offline model
  737. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  738. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  739. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  740. * @param [in] const string &model_file_name: Offline model filename.
  741. * @return SUCCESS handle successfully / others handle failed
  742. */
  743. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  744. const vector<GeTensor> &outputs, const string &model_file_name) {
  745. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  746. GELOGI("Start to build single op offline model, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  747. ModelBufferData model_buff;
  748. OpEngineType engine_type = ENGINE_SYS;
  749. Status status = BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true);
  750. GELOGI("Finish build single offline model, status: %u", status);
  751. return status;
  752. }
  753. /**
  754. * @ingroup ge
  755. * @brief Compiling a single operator into online buffer
  756. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  757. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  758. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  759. * @param [in] engine_type: specific engine.
  760. * @param [in] compile_flag: op build flag, compile flag by acl
  761. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  762. * @return SUCCESS handle successfully / others handle failed
  763. */
  764. // old process will be deleted
  765. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  766. const vector<GeTensor> &outputs, OpEngineType engine_type,
  767. ModelBufferData &model_buff) {
  768. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  769. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  770. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  771. GELOGI("Finish build single online model, status: %u", status);
  772. return status;
  773. }
  774. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  775. const vector<GeTensor> &outputs, OpEngineType engine_type, int32_t compile_flag,
  776. ModelBufferData &model_buff) {
  777. return SUCCESS;
  778. }
  779. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  780. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  781. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  782. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  783. // 1. Add Node to ComputeGraph.
  784. NodePtr op_node = compute_graph->AddNode(op_desc);
  785. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  786. // 2. Create InputData node.
  787. int32_t arg_index = 0;
  788. int32_t data_index = 0;
  789. if (inputs.empty()) {
  790. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  791. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  792. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  793. continue;
  794. }
  795. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
  796. arg_index++;
  797. }
  798. } else {
  799. for (const auto &in_desc : inputs) {
  800. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
  801. arg_index++;
  802. }
  803. }
  804. // 3. Create Output node.
  805. if (!outputs.empty()) {
  806. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  807. }
  808. // dump ComputeGraph node.
  809. compute_graph->Dump();
  810. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  811. return SUCCESS;
  812. }
  813. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  814. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  815. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  816. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  817. (void)graph_manager_.Finalize();
  818. return FAILED);
  819. return SUCCESS;
  820. }
  821. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  822. // set atc version
  823. if (!SetAtcVersionInfo(*(model.get()))) {
  824. GELOGW("SetPackageVersionInfo of atc failed!");
  825. }
  826. // set opp version
  827. if (!SetOppVersionInfo(*(model.get()))) {
  828. GELOGW("SetPackageVersionInfo of ops failed!");
  829. }
  830. ModelHelper model_helper;
  831. model_helper.SetSaveMode(is_offline_);
  832. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  833. if (ret != SUCCESS) {
  834. GELOGE(ret, "Save to om model failed");
  835. return ret;
  836. }
  837. return SUCCESS;
  838. }
  839. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  840. ModelBufferData &model_buff) {
  841. bool is_unknown_shape = false;
  842. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  843. if (ret != SUCCESS) {
  844. GELOGE(FAILED, "Check root model is unkonwn shape failed");
  845. return FAILED;
  846. }
  847. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  848. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED,
  849. "ge root model has no sub model")
  850. GeModelPtr model_root = nullptr;
  851. if (is_unknown_shape) {
  852. auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  853. model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  854. } else {
  855. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  856. }
  857. GE_CHECK_NOTNULL(model_root);
  858. // set atc version
  859. if (!SetAtcVersionInfo(*(model_root.get()))) {
  860. GELOGW("SetPackageVersionInfo of atc failed!");
  861. }
  862. // set opp version
  863. if (!SetOppVersionInfo(*(model_root.get()))) {
  864. GELOGW("SetPackageVersionInfo of ops failed!");
  865. }
  866. if (!SetOmSystemInfo(*(model_root.get()))) {
  867. GELOGW("SetOmsystemInfo failed!");
  868. }
  869. ModelHelper model_helper;
  870. model_helper.SetSaveMode(is_offline_);
  871. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  872. if (ret != SUCCESS) {
  873. GELOGE(ret, "Save to om model failed");
  874. return ret;
  875. }
  876. return SUCCESS;
  877. }
  878. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  879. GeRootModelPtr &ge_root_model) {
  880. static std::atomic<GraphId> atomic_graph_id(0);
  881. auto graph_id = atomic_graph_id.fetch_add(1);
  882. const std::map<std::string, std::string> options;
  883. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  884. if (ret != SUCCESS) {
  885. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", graph_id);
  886. (void)graph_manager_.Finalize();
  887. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  888. }
  889. graph_manager_.SetOptionsRunGraphFlag(false);
  890. static std::atomic<uint64_t> atomic_session_id(0);
  891. auto session_id = atomic_session_id.fetch_add(1);
  892. // This is a temporary add for graph with variable
  893. auto version = static_cast<int32_t>(SessionVersion::ClOUD_VERSION);
  894. ret = VarManager::Instance(session_id)->Init(version, session_id, kDefaultDeviceId, kDefaultJobId);
  895. GELOGI("Start init var instance, session_id %lu", session_id);
  896. if (ret != SUCCESS) {
  897. GELOGW("Failed init var instance, session_id %lu", session_id);
  898. }
  899. if (is_singleop_unregistered_) {
  900. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  901. } else {
  902. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  903. }
  904. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  905. if (ret != SUCCESS) {
  906. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", graph_id);
  907. ret = GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  908. }
  909. RtContextUtil::GetInstance().DestroyRtContexts(session_id);
  910. Analyzer::GetInstance()->DestroySessionJsonObject(session_id);
  911. VarManagerPool::Instance().RemoveVarManager(session_id);
  912. return ret;
  913. }
  914. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  915. static std::atomic<GraphId> atomic_graph_id(0);
  916. auto graph_id = atomic_graph_id.fetch_add(1);
  917. const std::map<std::string, std::string> options;
  918. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  919. if (ret != SUCCESS) {
  920. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", graph_id);
  921. (void)graph_manager_.Finalize();
  922. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  923. }
  924. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  925. if (ret != SUCCESS) {
  926. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed");
  927. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  928. }
  929. return SUCCESS;
  930. }
  931. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：