You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cc 58 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <dlfcn.h>
  18. #include <gflags/gflags.h>
  19. #include <sys/types.h>
  20. #include <unistd.h>
  21. #include <cctype>
  22. #include <climits>
  23. #include <cstdlib>
  24. #include <iostream>
  25. #include "common/gflags_util.h"
  26. #include "common/util.h"
  27. #include "common/util/error_manager/error_manager.h"
  28. #include "framework/common/debug/ge_log.h"
  29. #include "ge/ge_api.h"
  30. #include "generator/ge_generator.h"
  31. #include "graph/anchor.h"
  32. #include "graph/debug/ge_attr_define.h"
  33. #include "graph/graph.h"
  34. #include "graph/op_desc.h"
  35. #include "graph/utils/graph_utils.h"
  36. #include "graph/utils/type_utils.h"
  37. #include "init/gelib.h"
  38. #include "ir_build/atc_ir_common.h"
  39. #include "omg/omg.h"
  40. #include "omg/parser/parser_factory.h"
  41. #include "omg/parser/parser_inner_ctx.h"
  42. #include "parser/common/register_tbe.h"
  43. #include "register/op_registry.h"
  44. #include "single_op_parser.h"
  45. using domi::BuildMode;
  46. using domi::OpRegistrationData;
  47. using domi::OpRegistry;
  48. using domi::Status;
  49. using domi::SUCCESS;
  50. using ge::GEN_OM_MODEL;
  51. using ge::GflagsUtils;
  52. using ge::MODEL_TO_JSON;
  53. using ge::ONLY_PRE_CHECK;
  54. using ge::ParseInputShape;
  55. using ge::PBTXT_TO_JSON;
  56. using std::map;
  57. using std::pair;
  58. using std::shared_ptr;
  59. using std::string;
  60. using std::vector;
  // Handed to ge::CheckDynamicInputParamValid() while the flags are checked.
  // NOTE(review): presumably filled in there as an out-parameter -- confirm against its declaration.
  static bool is_dynamic_input{false};

  // 310 limited 8G size; kept as an arithmetic expression in text form.
  const char *const kGraphMemoryManagerMallocMaxSize{"8*1024*1024*1024"};

  // Text listing the supported values of --mode.
  const char *const kModeSupport{"only support 0(model to framework model), "
                                 "1(framework model to json), 3(only pre-check), 5(pbtxt to json)"};

  // Text listing the frameworks supported when a model is converted to json.
  const char *const kModelToJsonSupport{"only support 0(Caffe) 3(TensorFlow)"};

  // limit available mem size 2G
  // NOTE(review): the value is 2 * 1024 * 1024, so the unit is presumably KB -- confirm at the point of use.
  const long kMinAvailableMem{2L * 1024L * 1024L};
  69. DEFINE_string(model, "", "The model file.");
  70. DEFINE_string(output, "", "The output file path&name.");
  71. DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow).");
  72. DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe.");
  73. DEFINE_string(input_shape, "",
  74. "Optional; shape of input data. Required when framework is caffe "
  75. "or TensorFLow or MindSpore."
  76. "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"");
  77. DEFINE_bool(h, false, "show this help message");
  78. DEFINE_string(cal_conf, "", "Optional; the calibration config file.");
  79. DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op.");
  80. DEFINE_string(op_name_map, "", "Optional; custom op name mapping file.");
  81. DEFINE_string(target, "", "Optional; mini.");
  82. DEFINE_string(om, "", "The model file to be converted to json.");
  83. DEFINE_string(json, "", "The output json file path&name which is converted from a model.");
  84. DEFINE_int32(mode, 0,
  85. "Optional; run mode, 0(default): model => framework model; 1: "
  86. "framework model => json; 3: only pre-check; 5: pbtxt => json.");
  87. #if !defined(__ANDROID__) && !defined(ANDROID)
  88. DEFINE_int32(encrypt_mode, -1, "Optional; the encrypt flag. 0: encrypt; -1(default): not encrypt");
  89. DEFINE_string(encrypt_key, "", "Optional; the encrypt_key file.");
  90. DEFINE_string(certificate, "", "Optional; the certificate file.");
  91. DEFINE_string(hardware_key, "", "Optional; the ISV key file.");
  92. DEFINE_string(private_key, "", "Optional; the private key file.");
  93. #endif
  94. DEFINE_string(out_nodes, "",
  95. "Optional; output nodes designated by users."
  96. "Format: \"node_name1:0;node_name1:1;node_name2:0\"");
  97. DEFINE_string(precision_mode, "force_fp16",
  98. "Optional; precision mode."
  99. "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype.");
  100. DEFINE_string(input_format, "",
  101. "Optional; input_format, format of input data, NCHW;NHWC."
  102. "Format:\"NHWC\"");
  103. DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file.");
  104. DEFINE_string(input_fp16_nodes, "",
  105. "Optional; input node datatype is fp16 and format is NC1HWC0."
  106. "Format:\"node_name1;node_name2\"");
  107. DEFINE_string(is_output_adjust_hw_layout, "",
  108. "Optional; Net output node's datatype is fp16 and format is "
  109. "NC1HWC0, or not."
  110. "Format:\"false,true,false,true\"");
  111. DEFINE_string(is_input_adjust_hw_layout, "",
  112. "Optional; Intput node's datatype is fp16 and format is "
  113. "NC1HWC0, or not."
  114. "Format:\"false,true,false,true\"");
  115. DEFINE_string(output_type, "",
  116. "Optional; output type! "
  117. "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE.");
  118. DEFINE_string(op_select_implmode, "",
  119. "Optional; op select implmode! "
  120. "Support high_precision, high_performance.");
  121. DEFINE_string(optypelist_for_implmode, "",
  122. "Optional; Nodes need use implmode selected in op_select_implmode "
  123. "Format:\"node_name1,node_name2\"");
  124. DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file.");
  125. DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if.");
  126. DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode.");
  127. DEFINE_string(soc_version, "", "The soc version.");
  128. DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core.");
  129. DEFINE_string(aicore_num, "", "Optional; Set aicore num");
  130. DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize");
  131. DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path");
  132. DEFINE_string(save_original_model, "", "Optional; enable output original offline model. false(default)");
  133. DEFINE_string(dynamic_batch_size, "",
  134. "Optional; If set, generate dynamic multi batch model. "
  135. "Different batch sizes are split by ','."
  136. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  137. DEFINE_string(dynamic_image_size, "",
  138. "Optional; If set, generate dynamic multi image size model."
  139. "Different groups of image size are split by ';',"
  140. "while different dimensions of each group are split by ','."
  141. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  142. DEFINE_string(dynamic_dims, "",
  143. "Optional; If set, generate dynamic input size model. "
  144. "Different groups of size are split by ';', while different dimensions of each group are split by ','."
  145. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  146. DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled.");
  147. DEFINE_string(enable_compress_weight, "false",
  148. "Optional; enable compress weight. true: enable; false(default): disable");
  149. DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight");
  150. DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable");
  151. DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null");
  152. DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0.");
  153. DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;"
  154. "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler");
  155. DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass,"
  156. "multiple names can be set and separated by ','.");
  157. DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation");
  158. DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator compilation files");
  159. DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode");
  160. class GFlagUtils {
  161. public:
  162. /**
  163. * @name InitGFlag
  164. * @brief initialize gflag
  165. * @return void
  166. */
  167. static void InitGFlag(int argc, char *argv[]) {
  168. // -help
  169. gflags::SetUsageMessage(
  170. "usage: ./atc <args>\n"
  171. "generate offline model example:\n"
  172. "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n"
  173. "--framework=0 --output=./domi \n"
  174. "generate offline model for single op example:\n"
  175. "./atc --singleop=./op_list.json --output=./op_model \n"
  176. "===== Basic Functionality =====\n"
  177. "[General]\n"
  178. " --h/help Show this help message\n"
  179. " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format "
  180. "3: only pre-check; 5: convert pbtxt file to JSON format\n"
  181. "\n[Input]\n"
  182. " --model Model file\n"
  183. " --weight Weight file. Required when framework is Caffe\n"
  184. " --om The model file to be converted to json\n"
  185. " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow\n"
  186. " --input_format Format of input data. E.g.: \"NCHW\"\n"
  187. " --input_shape Shape of input data. Separate multiple nodes with semicolons (;)."
  188. "Use double quotation marks (\") to enclose each argument.\n"
  189. " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n"
  190. " --dynamic_batch_size Set dynamic batch size. E.g: \"batchsize1,batchsize2,batchsize3\"\n"
  191. " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;)."
  192. "Use double quotation marks (\") to enclose each argument.\n"
  193. " E.g: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n"
  194. " --dynamic_dims Set dynamic dims. Separate multiple nodes with semicolons (;)."
  195. "Use double quotation marks (\") to enclose each argument. E.g: \"dims1_n1,dims1_n2;dims2_n1,dims2_n2\"\n"
  196. " --singleop Single op definition file. atc will generate offline "
  197. "model(s) for single op if --singleop is set.\n"
  198. "\n[Output]\n"
  199. " --output Output file path&name(needn't suffix, will add "
  200. ".om automatically). \n"
  201. " If --singleop is set, this arg specifies the directory to "
  202. "which the single op offline model will be generated\n"
  203. " --output_type Set net output type. Support FP32, FP16, UINT8."
  204. "E.g.: FP16, indicates that all out nodes are set to FP16.\n"
  205. " \"node1:0:FP16;node2:1:FP32\", indicates setting the datatype of multiple out nodes.\n"
  206. " --check_report The pre-checking report file. Default value is: "
  207. "\"check_result.json\"\n"
  208. " --json The output json file path&name which is "
  209. "converted from a model\n"
  210. "\n[Target]\n"
  211. " --soc_version The soc version.\n"
  212. " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. "
  213. "Default value is: AiCore\n"
  214. " --aicore_num Set aicore num\n"
  215. "===== Advanced Functionality =====\n"
  216. "[Feature]\n"
  217. " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)."
  218. "Use double quotation marks (\") to enclose each argument.\n"
  219. " E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n"
  220. " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons "
  221. "(;)."
  222. "Use double quotation marks (\") to enclose each argument."
  223. "E.g.: \"node_name1;node_name2\"\n"
  224. " --insert_op_conf Config file to insert new op\n"
  225. " --op_name_map Custom op name mapping file\n"
  226. " Note: A semicolon(;) cannot be included in each "
  227. "path, otherwise the resolved path will not match the expected one.\n"
  228. " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is "
  229. "NC1HWC0, used with input_fp16_nodes E.g.: \"true,true,false,true\"\n"
  230. " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is "
  231. "NC1HWC0, used with out_nodes. E.g.: \"true,true,false,true\"\n"
  232. "\n[Model Tuning]\n"
  233. " --disable_reuse_memory The switch of reuse memory. Default value is : 0."
  234. "0 means reuse memory, 1 means do not reuse memory.\n"
  235. " --fusion_switch_file Set fusion switch file path\n"
  236. " --enable_scope_fusion_passes validate the non-general scope fusion passes,"
  237. "multiple names can be set and separated by ','. E.g.: ScopePass1,ScopePass2,...\n"
  238. " --enable_single_stream Enable single stream. true: enable; false(default): disable\n"
  239. " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n"
  240. " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n"
  241. " --compress_weight_conf Config file to compress weight\n"
  242. " --buffer_optimize Set buffer optimize. \"l2_optimize\" (default). Set \"off_optimize\" to close\n"
  243. "\n[Operator Tuning]\n"
  244. " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, "
  245. "allow_fp32_to_fp16, must_keep_origin_dtype.\n"
  246. " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
  247. " --op_select_implmode Set op select implmode. Support high_precision, high_performance."
  248. "default: high_performance\n"
  249. " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n"
  250. " Separate multiple nodes with commas (,). Use double quotation marks (\") "
  251. " to enclose each argument. E.g.: \"node_name1,node_name2\"\n"
  252. " --op_debug_level Debug enable for TBE operator building.\n"
  253. " 0 (default): Disable debug; 1: Enable TBE pipe_all, "
  254. "and generate the operator CCE file and Python-CCE mapping file (.json);\n"
  255. " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file "
  256. "(.json), and enable the CCE compiler -O0-g.\n"
  257. "\n[Debug]\n"
  258. " --save_original_model Control whether to output original model. E.g.: true: output original model\n"
  259. " --log Generate log with level. Support debug, info, warning, error, null\n"
  260. " --dump_mode The switch of dump json with shape, to be used with mode 1."
  261. "0(default): disable; 1: enable.");
  262. gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
  263. // Using gflags to analyze input parameters
  264. GflagsUtils::ChangeHelpFlags(FLAGS_h);
  265. gflags::HandleCommandLineHelpFlags();
  266. }
  267. static Status CheckDumpInfershapeJsonFlags() {
  268. Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight);
  269. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED,
  270. "check custom aicpu run so failed!");
  271. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  272. FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"),
  273. return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!",
  274. FLAGS_weight.c_str());
  275. return domi::SUCCESS;
  276. }
  277. static Status CheckFlags() {
  278. Status ret = ge::SUCCESS;
  279. // No model file information passed in
  280. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  281. FLAGS_model == "",
  282. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"});
  283. ret = ge::FAILED, "Input parameter[--model]'s value is empty!");
  284. // check param disable_reuse_memory
  285. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  286. ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) != ge::SUCCESS,
  287. ret = ge::FAILED, "check disable_reuse_memory failed!");
  288. // check optypelist_for_implmode and op_select_implmode
  289. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  290. ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode,
  291. FLAGS_op_select_implmode) != ge::SUCCESS,
  292. ret = ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!");
  293. // No output file information passed in
  294. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  295. FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "",
  296. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"});
  297. ret = ge::FAILED, "Input parameter[--output]'s value is empty!");
  298. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  299. CheckFrameWorkValid(FLAGS_framework, FLAGS_weight) != ge::SUCCESS,
  300. ret = ge::FAILED,
  301. "CheckFrameWorkValid failed");
  302. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  303. ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size,
  304. FLAGS_dynamic_dims, FLAGS_input_shape,
  305. FLAGS_input_format, is_dynamic_input) != ge::SUCCESS,
  306. ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!");
  307. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  308. !FLAGS_insert_op_conf.empty() && !FLAGS_dynamic_dims.empty(),
  309. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  310. {"parameter", "value", "reason"},
  311. {"--insert_op_conf", FLAGS_insert_op_conf,
  312. "dynamic dims function does not support aipp"});
  313. ret = ge::FAILED, "dynamic dims function does not support aipp");
  314. #if !defined(__ANDROID__) && !defined(ANDROID)
  315. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), ret = ge::FAILED,
  316. "encrypt_mode %d not valid!!", FLAGS_encrypt_mode);
  317. if (FLAGS_encrypt_mode == 0) { // Encryption mode
  318. GELOGI("ge will run with encrypt!");
  319. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), ret = ge::FAILED,
  320. "encrypt_key file not found!!");
  321. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), ret = ge::FAILED,
  322. "certificate file not found!!");
  323. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), ret = ge::FAILED,
  324. "hardware_key file not found!!");
  325. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), ret = ge::FAILED,
  326. "private_key file not found!!");
  327. } else { // No encryption
  328. GELOGI("ge will run without encrypt!");
  329. }
  330. #endif
  331. /**
  332. * Check the validity of the I / O file path
  333. */
  334. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  335. FLAGS_model != "" && !ge::CheckInputPathValid(FLAGS_model, "--model"), ret = ge::FAILED,
  336. "model file %s not found!!", FLAGS_model.c_str());
  337. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  338. FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"),
  339. ret = ge::FAILED, "weight file %s not found!!",
  340. FLAGS_weight.c_str());
  341. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  342. FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"),
  343. ret = ge::FAILED, "calibration config file %s not found!!",
  344. FLAGS_cal_conf.c_str());
  345. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  346. FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"),
  347. ret = ge::FAILED, "op config file %s not found!!",
  348. FLAGS_op_name_map.c_str());
  349. GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS,
  350. ret = ge::FAILED, "check insert op conf failed!");
  351. GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid(
  352. FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS,
  353. ret = ge::FAILED, "check compress weight failed!");
  354. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  355. !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED,
  356. "check_report file %s not found!!", FLAGS_check_report.c_str());
  357. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  358. FLAGS_mode == GEN_OM_MODEL && FLAGS_output != "" &&
  359. (!ge::CheckOutputPathValid(FLAGS_output, "--output") || !CheckPathWithName(FLAGS_output)),
  360. ret = ge::FAILED, "output path %s is not valid!!", FLAGS_output.c_str());
  361. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  362. FLAGS_save_original_model != "" &&
  363. FLAGS_save_original_model != "true" &&
  364. FLAGS_save_original_model != "false",
  365. ErrorManager::GetInstance().ATCReportErrMessage(
  366. "E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model});
  367. ret = ge::FAILED,
  368. "Input parameter[--save_original_model]'s value[%s] must be true or false.",
  369. FLAGS_save_original_model.c_str());
  370. GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS,
  371. ret = ge::FAILED, "check output type failed!");
  372. GE_CHK_BOOL_EXEC(
  373. ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS,
  374. ret = ge::FAILED, "check enable single stream failed!");
  375. return ret;
  376. }
  377. /**
  378. * Verifying the parameters of converting model to JSON
  379. * 1. Fmk_model
  380. * 2. out_json
  381. **/
  382. static Status CheckConverJsonParamFlags() {
  383. Status ret = ge::SUCCESS;
  384. // No model path passed in
  385. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
  386. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
  387. ret = ge::FAILED,
  388. "Input parameter[--om]'s value is empty!!");
  389. // JSON path not passed in
  390. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "",
  391. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"});
  392. ret = ge::FAILED,
  393. "Input parameter[--json]'s value is empty!!");
  394. // Check if the model path is valid
  395. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  396. FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
  397. ret = ge::FAILED,
  398. "model file path is invalid: %s.", FLAGS_om.c_str());
  399. // Check whether the JSON path is valid
  400. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  401. FLAGS_json != "" && !ge::CheckOutputPathValid(FLAGS_json, "--json"),
  402. ret = ge::FAILED,
  403. "json file path is invalid: %s.", FLAGS_json.c_str());
  404. return ret;
  405. }
  406. /**
  407. * Check command line parameters for explicit settings
  408. * true: Explicit setup
  409. * false: Not set up
  410. * */
  411. static bool CheckFlagSet(string flag) {
  412. gflags::CommandLineFlagInfo info;
  413. return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default);
  414. }
  415. private:
  416. static bool CheckEncryptModeValid(const int encrypt_mode) {
  417. #if !defined(__ANDROID__) && !defined(ANDROID)
  418. if (encrypt_mode != 0 && encrypt_mode != -1) {
  419. DOMI_LOGE("encrypt mode must be 0 or -1");
  420. return false;
  421. }
  422. #else
  423. if (encrypt_mode != -1) {
  424. DOMI_LOGE("encrypt mode must be -1");
  425. return false;
  426. }
  427. #endif
  428. return true;
  429. }
  430. static Status CheckFrameWorkValid(int framework, const std::string weight_file) {
  431. if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW &&
  432. framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) {
  433. // No framework information was passed in or the entered framework is illegal
  434. ErrorManager::GetInstance().ATCReportErrMessage(
  435. "E10007", {"parameter", "support"},
  436. {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"});
  437. DOMI_LOGE("Input parameter[--framework] is mandatory and it's value must be: "
  438. "0(Caffe) or 1(MindSpore) or 3(TensorFlow).");
  439. return domi::PARAM_INVALID;
  440. }
  441. if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) {
  442. ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"});
  443. DOMI_LOGE("Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!");
  444. return domi::PARAM_INVALID;
  445. }
  446. if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) {
  447. GELOGW("Parameter weight is ignored for TensorFlow.");
  448. }
  449. if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) {
  450. GELOGW("Parameter weight is ignored for Onnx.");
  451. }
  452. return domi::SUCCESS;
  453. }
  454. static bool CheckPathWithName(const std::string &fileName) {
  455. // Determine file path length
  456. if (fileName.size() > static_cast<int>(PATH_MAX)) {
  457. ErrorManager::GetInstance().ATCReportErrMessage(
  458. "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)});
  459. GELOGE(ge::FAILED, "Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX);
  460. return false;
  461. }
  462. // Find the last separator
  463. int slashPosition = fileName.size() - 1;
  464. for (; slashPosition >= 0; slashPosition--) {
  465. if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') {
  466. break;
  467. }
  468. }
  469. // Failure if no filename follows the path
  470. if (slashPosition == static_cast<int>(fileName.size() - 1)) {
  471. ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName});
  472. DOMI_LOGE("Input parameter[--output]'s path[%s] not include file name", fileName.c_str());
  473. return false;
  474. }
  475. return true;
  476. }
  477. };
  478. void SetDynamicInputSizeOptions() {
  479. if (!FLAGS_dynamic_batch_size.empty()) {
  480. domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size;
  481. }
  482. if (!FLAGS_dynamic_image_size.empty()) {
  483. domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size;
  484. }
  485. if (!FLAGS_dynamic_dims.empty()) {
  486. domi::GetContext().dynamic_dims = FLAGS_dynamic_dims;
  487. }
  488. }
  489. /// Validate the non-general scope fusion pass.
  490. /// The parameter is set to the name of the fusion rule.
  491. /// Multiple names can be set and separated by ",".
  492. void SetEnableScopeFusionPasses(const std::string pass_names) {
  493. ge::GetParserContext().enable_scope_fusion_passes = pass_names;
  494. }
  495. static bool CheckInputFormat() {
  496. if (FLAGS_input_format.empty()) {
  497. // Set default format
  498. if (FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW)) {
  499. FLAGS_input_format = "NHWC";
  500. } else {
  501. FLAGS_input_format = "NCHW";
  502. }
  503. return true;
  504. } else if ((FLAGS_framework == static_cast<int32_t>(domi::CAFFE))) { // caffe
  505. if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) {
  506. return true;
  507. }
  508. // only support NCHW ND
  509. ErrorManager::GetInstance().ATCReportErrMessage(
  510. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport});
  511. GELOGE(ge::FAILED,
  512. "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kCaffeFormatSupport);
  513. return false;
  514. } else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf
  515. if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) {
  516. return true;
  517. }
  518. // only support NCHW NHWC ND NCDHW NDHWC
  519. ErrorManager::GetInstance().ATCReportErrMessage(
  520. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kTFFormatSupport});
  521. GELOGE(ge::FAILED,
  522. "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport);
  523. return false;
  524. } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
  525. if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) {
  526. return true;
  527. }
  528. // only support NCHW ND
  529. ErrorManager::GetInstance().ATCReportErrMessage(
  530. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport});
  531. GELOGE(ge::FAILED,
  532. "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport);
  533. return false;
  534. }
  535. return true;
  536. }
  537. #if !defined(__ANDROID__) && !defined(ANDROID)
  538. static void GetCustomOpPath(std::string &customop_path) {
  539. GELOGI("Enter get custom op path schedule");
  540. std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast<domi::FrameworkType>(FLAGS_framework));
  541. GELOGI("Framework type is %s.", fmk_type.c_str());
  542. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  543. if (path_env != nullptr) {
  544. std::string path = path_env;
  545. customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
  546. GELOGI("Get custom so path from env : %s", path_env);
  547. return;
  548. }
  549. std::string path_base = ge::GELib::GetPath();
  550. GELOGI("path_base is %s", path_base.c_str());
  551. path_base = path_base.substr(0, path_base.rfind('/'));
  552. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  553. customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type);
  554. return;
  555. }
  556. void GetPluginSoFileList(const string &path, vector<string> &fileList, string &caffe_parser_path) {
  557. // Support to split multiple so directories by ":"
  558. GELOGI("path is %s", path.c_str());
  559. vector<string> v_path = ge::StringUtils::Split(path, ':');
  560. for (size_t i = 0; i < v_path.size(); ++i) {
  561. ge::FindParserSo(v_path[i], fileList, caffe_parser_path);
  562. GELOGI("CustomOpLib full name = %s", v_path[i].c_str());
  563. }
  564. }
  565. void LoadModelParserLib(std::string caffe_parser_path) {
  566. if (FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW)) {
  567. void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL);
  568. if (tf_handle == nullptr) {
  569. GELOGW("dlopen fmk library [libfmk_parser.so] failed.");
  570. return;
  571. }
  572. GELOGI("plugin load libfmk_parser.so success.");
  573. } else if (FLAGS_framework == static_cast<int32_t>(domi::CAFFE)) {
  574. // What we are dealing with here is that the user modifies the caffe.proto scenario.
  575. // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path.
  576. caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path;
  577. void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
  578. if (handle == nullptr) {
  579. GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror());
  580. return;
  581. }
  582. GELOGI("plugin load %s success.", caffe_parser_path.c_str());
  583. // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so).
  584. // (depend on the lib_caffe_parser.so)
  585. void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL);
  586. if (fmk_handle == nullptr) {
  587. GELOGW("dlopen fmk library [libfmk_parser.so] failed.");
  588. if (dlclose(handle) != 0) {
  589. GELOGW("dlclose lib_caffe_parser.so failed.");
  590. }
  591. return;
  592. }
  593. GELOGI("plugin load libfmk_parser.so success.");
  594. } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
  595. void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL);
  596. if (handle == nullptr) {
  597. GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed.");
  598. return;
  599. }
  600. GELOGI("plugin load libfmk_onnx_parser.so success.");
  601. } else {
  602. GELOGW("Framework:%s is not support.",
  603. ge::TypeUtils::FmkTypeToSerialString(static_cast<domi::FrameworkType>(FLAGS_framework)).c_str());
  604. return;
  605. }
  606. return;
  607. }
  608. void LoadCustomOpLib(bool need_load_ops_plugin) {
  609. std::string plugin_path;
  610. GetCustomOpPath(plugin_path);
  611. vector<string> fileList;
  612. string caffe_parser_path = "";
  613. // whether there are files in the plugin so path
  614. GetPluginSoFileList(plugin_path, fileList, caffe_parser_path);
  615. // no file
  616. if (fileList.empty() && caffe_parser_path.empty()) {
  617. GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str());
  618. }
  619. LoadModelParserLib(caffe_parser_path);
  620. if (!need_load_ops_plugin) {
  621. GELOGI("No need to load ops plugin so.");
  622. return;
  623. }
  624. OpRegistry::Instance()->registrationDatas.clear();
  625. // load other so files except lib_caffe_parser.so in the plugin so path
  626. for (auto elem : fileList) {
  627. ge::StringUtils::Trim(elem);
  628. void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL);
  629. if (handle == nullptr) {
  630. GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror());
  631. } else {
  632. GELOGI("plugin load %s success.", elem.c_str());
  633. }
  634. }
  635. std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
  636. for (OpRegistrationData reg_data : registrationDatas) {
  637. if (reg_data.GetFrameworkType() == static_cast<domi::FrameworkType>(FLAGS_framework)) {
  638. (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data);
  639. (void)OpRegistry::Instance()->Register(reg_data);
  640. }
  641. }
  642. }
  643. void SaveCustomCaffeProtoPath() {
  644. GELOGI("Enter save custom caffe proto path.");
  645. std::string path_base = ge::GELib::GetPath();
  646. GELOGI("path_base is %s", path_base.c_str());
  647. path_base = path_base.substr(0, path_base.rfind('/'));
  648. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  649. ge::GetParserContext().caffe_proto_path = path_base + "include/proto/";
  650. string customop_path;
  651. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  652. if (path_env != nullptr) {
  653. std::string path = path_env;
  654. customop_path = path + "/framework/custom/caffe/";
  655. GELOGI("Get custom proto path from env : %s", path_env);
  656. ge::GetParserContext().custom_proto_path = customop_path;
  657. return;
  658. }
  659. customop_path = path_base + "ops/framework/custom/caffe/";
  660. ge::GetParserContext().custom_proto_path = customop_path;
  661. return;
  662. }
  663. #endif
  664. Status CreateInputsForInference(const ge::Graph &graph, vector<ge::GeTensor> &inputs) {
  665. auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  666. GE_CHECK_NOTNULL(compute_graph);
  667. for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) {
  668. GE_CHECK_NOTNULL(input_node);
  669. ge::OpDescPtr op = input_node->GetOpDesc();
  670. GE_CHECK_NOTNULL(op);
  671. if (op->GetType() == ge::DATA) {
  672. GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size());
  673. ge::GeTensorDesc tensor = op->GetInputDesc(0);
  674. string data_op_name = op->GetName();
  675. GELOGI("Data op name is: %s", data_op_name.c_str());
  676. ge::GeShape data_shape;
  677. auto iter = domi::GetContext().input_dims.find(data_op_name);
  678. if (iter != domi::GetContext().input_dims.end()) {
  679. data_shape = ge::GeShape(iter->second);
  680. GELOGI("Data op get shape from Context.");
  681. } else {
  682. data_shape = tensor.GetShape();
  683. GELOGI("Data op get shape from InputDesc in geir graph.");
  684. }
  685. ge::DataType data_type = tensor.GetDataType();
  686. string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type);
  687. GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str());
  688. ge::GeTensor input_tensor;
  689. ge::GeTensorDesc desc(data_shape, ge::Format(domi::GetContext().format), data_type);
  690. input_tensor.SetTensorDesc(desc);
  691. inputs.push_back(input_tensor);
  692. }
  693. }
  694. GELOGI("Build ME model, inputs size is: %zu", inputs.size());
  695. return ge::SUCCESS;
  696. }
  697. domi::Status GenerateInfershapeJson() {
  698. if (!CheckInputFormat()) {
  699. GELOGE(ge::FAILED, "Check input_format failed");
  700. return domi::FAILED;
  701. }
  702. Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags();
  703. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!");
  704. ge::GeGenerator ge_generator;
  705. std::map<string, string> options;
  706. ge::Status geRet = ge_generator.Initialize(options, domi::GetContext());
  707. if (geRet != ge::SUCCESS) {
  708. DOMI_LOGE("GeGenerator initialize failed!");
  709. return domi::FAILED;
  710. }
  711. ge::Graph graph;
  712. std::map<string, string> atc_params;
  713. atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
  714. ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework,
  715. "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false);
  716. if (ret != ge::SUCCESS) {
  717. DOMI_LOGE("ATC Parse graph domi::FAILED");
  718. (void)ge_generator.Finalize();
  719. return domi::FAILED;
  720. }
  721. geRet = ge_generator.GenerateInfershapeGraph(graph);
  722. if (geRet != ge::SUCCESS) {
  723. DOMI_LOGE("ATC GenerateInfershapeJson failed");
  724. (void)ge_generator.Finalize();
  725. return domi::FAILED;
  726. }
  727. if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) {
  728. DOMI_LOGE("ATC DumpInfershapeJson failed");
  729. (void)ge_generator.Finalize();
  730. return domi::FAILED;
  731. }
  732. (void)ge_generator.Finalize();
  733. return ge::SUCCESS;
  734. }
  735. static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) {
  736. Status ret = ge::SUCCESS;
  737. if (fwk_type == -1) {
  738. ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str());
  739. return ret;
  740. }
  741. if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) {
  742. ErrorManager::GetInstance().ATCReportErrMessage(
  743. "E10001", {"parameter", "value", "reason"},
  744. {"--framework", std::to_string(fwk_type), kModelToJsonSupport});
  745. GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport);
  746. ret = ge::FAILED;
  747. }
  748. if (FLAGS_dump_mode != "0" && FLAGS_dump_mode != "1") {
  749. ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"});
  750. GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0.");
  751. ret = ge::FAILED;
  752. }
  753. if (ret != ge::SUCCESS) return ret;
  754. // Need to save caffe.proto path
  755. SaveCustomCaffeProtoPath();
  756. if (FLAGS_dump_mode == "0") {
  757. // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so.
  758. LoadCustomOpLib(false);
  759. ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str());
  760. } else if (FLAGS_dump_mode == "1") {
  761. // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so.
  762. LoadCustomOpLib(true);
  763. ret = GenerateInfershapeJson();
  764. }
  765. return ret;
  766. }
  767. domi::Status GenerateModel(std::map<string, string> &options, std::string output) {
  768. ge::GeGenerator ge_generator;
  769. ge::Status geRet = ge::SUCCESS;
  770. std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance();
  771. if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
  772. geRet = ge::GELib::Initialize(options);
  773. if (geRet != ge::SUCCESS) {
  774. DOMI_LOGE("GE initialize failed!");
  775. return domi::FAILED;
  776. }
  777. }
  778. geRet = ge_generator.Initialize(options, domi::GetContext());
  779. if (geRet != ge::SUCCESS) {
  780. DOMI_LOGE("GeGenerator initialize failed!");
  781. (void)ge::GELib::GetInstance()->Finalize();
  782. return domi::FAILED;
  783. }
  784. ge::Graph graph;
  785. std::vector<ge::GeTensor> inputs;
  786. if (FLAGS_framework == domi::MINDSPORE) {
  787. // load model from file
  788. ge::Model load_model = ge::Model("loadmodel", "version2");
  789. auto ret1 = load_model.LoadFromFile(FLAGS_model);
  790. if (ret1 != ge::GRAPH_SUCCESS) {
  791. ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model});
  792. DOMI_LOGE("Load model from %s failed, please check model file or "
  793. "input parameter[--framework] is correct", FLAGS_model.c_str());
  794. (void)ge_generator.Finalize();
  795. (void)ge::GELib::GetInstance()->Finalize();
  796. return domi::FAILED;
  797. }
  798. graph = load_model.GetGraph();
  799. GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input),
  800. GELOGE(ge::FAILED, "ATC Generate call InitDomiOmgContext ret fail");
  801. (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED);
  802. Status ret = CreateInputsForInference(graph, inputs);
  803. if (ret != ge::SUCCESS) {
  804. GELOGE(ge::FAILED, "create inputs for inference failed.");
  805. (void)ge_generator.Finalize();
  806. (void)ge::GELib::GetInstance()->Finalize();
  807. return domi::FAILED;
  808. }
  809. } else {
  810. std::map<string, string> atc_params;
  811. atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape));
  812. atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes));
  813. atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
  814. atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report));
  815. atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes));
  816. atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout));
  817. atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout));
  818. atc_params.insert(std::pair<string, string>("compress_weight_conf", FLAGS_compress_weight_conf));
  819. atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
  820. atc_params.insert(std::pair<string, string>("output", output));
  821. Status ret =
  822. ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework,
  823. FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input);
  824. // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph
  825. if (FLAGS_mode == ge::ONLY_PRE_CHECK) {
  826. (void)ge_generator.Finalize();
  827. (void)ge::GELib::GetInstance()->Finalize();
  828. if (ret != ge::SUCCESS) {
  829. DOMI_LOGE("ATC precheck fail.");
  830. return domi::FAILED;
  831. }
  832. return domi::SUCCESS;
  833. }
  834. if (ret != ge::SUCCESS) {
  835. DOMI_LOGE("ATC Parse graph domi::FAILED");
  836. DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case
  837. (void)ge_generator.Finalize();
  838. (void)ge::GELib::GetInstance()->Finalize();
  839. return domi::FAILED;
  840. }
  841. if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) {
  842. DOMI_LOGE("Set output node info fail.");
  843. (void)ge_generator.Finalize();
  844. (void)ge::GELib::GetInstance()->Finalize();
  845. return domi::FAILED;
  846. }
  847. }
  848. geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);
  849. if (geRet != ge::SUCCESS) {
  850. DOMI_LOGE("GE GenerateOfflineModel execute failed");
  851. DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case
  852. // checking error log)
  853. (void)ge_generator.Finalize();
  854. (void)ge::GELib::GetInstance()->Finalize();
  855. return domi::FAILED;
  856. }
  857. (void)ge_generator.Finalize();
  858. (void)ge::GELib::GetInstance()->Finalize();
  859. return ge::SUCCESS;
  860. }
  861. static void SetEnvForSingleOp(std::map<string, string> &options) {
  862. string flag_on = "1";
  863. string flag_off = "0";
  864. options.emplace(ge::GE_FE_FLAG, flag_on);
  865. options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream
  866. options.emplace(ge::RUN_FLAG, flag_off);
  867. options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off);
  868. options.emplace(ge::SINGLE_OP_FLAG, flag_on);
  869. options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode);
  870. options.emplace(ge::SOC_VERSION, FLAGS_soc_version);
  871. options.emplace(ge::CORE_TYPE, FLAGS_core_type);
  872. options.emplace(ge::AICORE_NUM, FLAGS_aicore_num);
  873. options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode);
  874. options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode);
  875. options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode);
  876. options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize);
  877. options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level));
  878. options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir);
  879. options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir);
  880. options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode);
  881. }
  882. domi::Status GenerateSingleOp(const std::string& json_file_path) {
  883. if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) {
  884. DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str());
  885. return domi::FAILED;
  886. }
  887. // check optypelist_for_implmode and op_select_implmode
  888. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  889. ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS,
  890. return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!");
  891. std::map<string, string> options;
  892. // need to be changed when ge.ini plan is done
  893. SetEnvForSingleOp(options);
  894. auto ret = ge::GELib::Initialize(options);
  895. if (ret != ge::SUCCESS) {
  896. DOMI_LOGE("GE initialize failed!");
  897. return domi::FAILED;
  898. }
  899. ge::GeGenerator generator;
  900. ret = generator.Initialize(options, domi::GetContext());
  901. if (ret != SUCCESS) {
  902. DOMI_LOGE("GeGenerator initialize failed!");
  903. (void)ge::GELib::GetInstance()->Finalize();
  904. return domi::FAILED;
  905. }
  906. vector<ge::SingleOpBuildParam> build_params;
  907. if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) {
  908. DOMI_LOGE("parse single op json file failed");
  909. (void)generator.Finalize();
  910. (void)ge::GELib::GetInstance()->Finalize();
  911. return domi::FAILED;
  912. }
  913. int index = 0;
  914. for (auto &param : build_params) {
  915. string output_path;
  916. if (!FLAGS_output.empty()) {
  917. output_path = FLAGS_output + "/";
  918. }
  919. output_path += param.file_name;
  920. ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path);
  921. if (ret != SUCCESS) {
  922. DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index);
  923. ret = domi::FAILED;
  924. break;
  925. }
  926. GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str());
  927. index += 1;
  928. }
  929. (void)generator.Finalize();
  930. (void)ge::GELib::GetInstance()->Finalize();
  931. return ret;
  932. }
  933. domi::Status GenerateOmModel() {
  934. if (!CheckInputFormat()) {
  935. GELOGE(ge::FAILED, "Check input_format failed");
  936. return domi::FAILED;
  937. }
  938. Status ret = GFlagUtils::CheckFlags();
  939. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED,
  940. "Check flags failed! Please check whether some atc params that include semicolons[;] use double "
  941. "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size");
  942. #if !defined(__ANDROID__) && !defined(ANDROID)
  943. // Load custom operator Library
  944. LoadCustomOpLib(true);
  945. SaveCustomCaffeProtoPath();
  946. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!");
  947. #endif
  948. const int f_stream_num = 1;
  949. std::map<string, string> options;
  950. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework)));
  951. options.insert(std::pair<string, string>(string(ge::STREAM_NUM), to_string(f_stream_num)));
  952. options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf));
  953. options.insert(std::pair<string, string>(string(ge::ENCRYPT_MODE), to_string(FLAGS_encrypt_mode)));
  954. options.insert(std::pair<string, string>(string(ge::EK_FILE), FLAGS_encrypt_key));
  955. options.insert(std::pair<string, string>(string(ge::CERT_FILE), FLAGS_certificate));
  956. options.insert(std::pair<string, string>(string(ge::HW_KEY_FILE), FLAGS_hardware_key));
  957. options.insert(std::pair<string, string>(string(ge::PRIVATE_KEY_FILE), FLAGS_private_key));
  958. options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes));
  959. options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf));
  960. options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode));
  961. options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
  962. options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));
  963. if (!FLAGS_output_type.empty()) {
  964. options.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
  965. }
  966. options.insert(std::pair<string, string>(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode));
  967. options.insert(std::pair<string, string>(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode));
  968. if (!FLAGS_input_fp16_nodes.empty()) {
  969. GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str());
  970. options.insert(std::pair<string, string>(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes));
  971. }
  972. options.insert(std::pair<string, string>(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode));
  973. options.insert(
  974. std::pair<string, string>(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory)));
  975. options.insert(std::pair<string, string>(string(ge::SOC_VERSION), FLAGS_soc_version));
  976. options.insert(std::pair<string, string>(string(ge::CORE_TYPE), FLAGS_core_type));
  977. options.insert(std::pair<string, string>(string(ge::AICORE_NUM), FLAGS_aicore_num));
  978. options.insert(std::pair<string, string>(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize));
  979. options.insert(std::pair<string, string>(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel));
  980. options.insert(std::pair<string, string>(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file));
  981. options.insert(std::pair<string, string>(string(ge::ENABLE_COMPRESS_WEIGHT),
  982. (FLAGS_enable_compress_weight == "true") ?
  983. ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse));
  984. options.insert(std::pair<string, string>(string(ge::GRAPH_MEMORY_MAX_SIZE), kGraphMemoryManagerMallocMaxSize));
  985. options.insert(std::pair<string, string>(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream));
  986. options.insert(std::pair<string, string>(string(ge::DEBUG_DIR), FLAGS_debug_dir));
  987. options.insert(std::pair<string, string>(string(ge::OP_COMPILER_CACHE_DIR), FLAGS_op_compiler_cache_dir));
  988. options.insert(std::pair<string, string>(string(ge::OP_COMPILER_CACHE_MODE), FLAGS_op_compiler_cache_mode));
  989. SetDynamicInputSizeOptions();
  990. if (!FLAGS_save_original_model.empty()) {
  991. options.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model));
  992. options.insert(std::pair<string, string>(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om"));
  993. }
  994. options.insert(std::pair<string, string>(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level)));
  995. // set enable scope fusion passes
  996. SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes);
  997. // print atc option map
  998. ge::PrintOptionMap(options, "atc option");
  999. // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name
  1000. FLAGS_output = FLAGS_output + ".om";
  1001. ret = GenerateModel(options, FLAGS_output);
  1002. if (ret != domi::SUCCESS) {
  1003. return domi::FAILED;
  1004. }
  1005. return domi::SUCCESS;
  1006. }
  1007. domi::Status ConvertModelToJson() {
  1008. Status ret = GFlagUtils::CheckConverJsonParamFlags();
  1009. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!");
  1010. ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json);
  1011. GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED);
  1012. return domi::SUCCESS;
  1013. }
  1014. bool CheckRet(domi::Status ret) {
  1015. if (ret != domi::SUCCESS) {
  1016. if (FLAGS_mode == ONLY_PRE_CHECK) {
  1017. GELOGW("ATC precheck failed.");
  1018. } else if (FLAGS_mode == GEN_OM_MODEL) {
  1019. GELOGW("ATC generate offline model failed.");
  1020. } else if (FLAGS_mode == MODEL_TO_JSON) {
  1021. GELOGW("ATC convert model to json file failed.");
  1022. } else if (FLAGS_mode == PBTXT_TO_JSON) {
  1023. GELOGW("ATC convert pbtxt to json file failed.");
  1024. } else {
  1025. return false;
  1026. }
  1027. return false;
  1028. }
  1029. if (FLAGS_mode == ONLY_PRE_CHECK) {
  1030. GELOGI("ATC precheck success.");
  1031. } else if (FLAGS_mode == GEN_OM_MODEL) {
  1032. GELOGI("ATC generate offline model success.");
  1033. } else if (FLAGS_mode == MODEL_TO_JSON) {
  1034. GELOGI("ATC convert model to json file success.");
  1035. } else if (FLAGS_mode == PBTXT_TO_JSON) {
  1036. GELOGI("ATC convert pbtxt to json file success.");
  1037. }
  1038. return true;
  1039. }
  1040. domi::Status ConvertPbtxtToJson() {
  1041. Status ret = GFlagUtils::CheckConverJsonParamFlags();
  1042. if (ret != domi::SUCCESS) {
  1043. GELOGE(ge::FAILED, "Check convert json params flags failed!");
  1044. return domi::FAILED;
  1045. }
  1046. ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str());
  1047. if (ret != domi::SUCCESS) {
  1048. GELOGE(ge::FAILED, "ConvertPbtxtToJson fail.");
  1049. return domi::FAILED;
  1050. }
  1051. return domi::SUCCESS;
  1052. }
  1053. int init(int argc, char* argv[]) {
  1054. GFlagUtils::InitGFlag(argc, argv);
  1055. // set log level
  1056. int ret = -1;
  1057. const std::set<string> log_level = {"null", "debug", "info", "warning", "error"};
  1058. if (log_level.count(FLAGS_log) == 0) {
  1059. std::cout << "E10010: invalid value for --log:" << FLAGS_log
  1060. <<", only support debug, info, warning, error, null"<< std::endl;
  1061. return ret;
  1062. }
  1063. ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log);
  1064. if (ret != 0) {
  1065. return ret;
  1066. }
  1067. std::string path_base = ge::GELib::GetPath();
  1068. ret = ErrorManager::GetInstance().Init(path_base);
  1069. if (ret != 0) {
  1070. DOMI_LOGE("ErrorManager init fail !");
  1071. return ret;
  1072. }
  1073. return 0;
  1074. }
  1075. long GetMemInfo(const std::string &key) {
  1076. std::string file_path = "/proc/meminfo";
  1077. std::ifstream fs(file_path, std::ifstream::in);
  1078. if (!fs.is_open()) {
  1079. GELOGW("Can not open %s .", file_path.c_str());
  1080. return 0;
  1081. }
  1082. std::string line;
  1083. while (getline(fs, line)) { // line not with \n
  1084. if (line.find(key) != std::string::npos) {
  1085. GELOGI("Find mem [%s] info line [%s]", key.c_str(), line.c_str());
  1086. fs.close();
  1087. size_t pos = line.find(":");
  1088. if (pos == std::string::npos) {
  1089. return 0;
  1090. }
  1091. std::string current_mem_info_str = line.substr(pos + 1);
  1092. ge::StringUtils::Trim(current_mem_info_str);
  1093. GELOGI("Find mem [%s] info [%s].", key.c_str(), current_mem_info_str.c_str());
  1094. return stol(current_mem_info_str);
  1095. }
  1096. }
  1097. fs.close(); // close the file
  1098. return 0;
  1099. }
  1100. bool CheckMemInfo() {
  1101. if (FLAGS_auto_tune_mode.empty()) {
  1102. return true;
  1103. }
  1104. // only check current available mem when auto_tune_mode is set.
  1105. long current_mem_available = GetMemInfo("MemAvailable");
  1106. GELOGI("Get mem available [%lu].", current_mem_available);
  1107. std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl;
  1108. if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) {
  1109. GELOGE(ge::PARAM_INVALID, "Current available mem [%lu] can not be smaller than [%lu] .",
  1110. current_mem_available, kMinAvailableMem);
  1111. ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"},
  1112. {to_string(current_mem_available), to_string(kMinAvailableMem)});
  1113. return false;
  1114. }
  1115. return true;
  1116. }
  1117. int main(int argc, char* argv[]) {
  1118. Status ret = domi::SUCCESS;
  1119. std::cout << "ATC start working now, please wait for a moment." << std::endl;
  1120. // Initialize
  1121. if (init(argc, argv) != 0) {
  1122. std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl;
  1123. return -1;
  1124. }
  1125. do {
  1126. if (!CheckMemInfo()) {
  1127. GELOGE(ge::PARAM_INVALID, "Current available mem is too small");
  1128. ret = domi::FAILED;
  1129. break;
  1130. }
  1131. if (!FLAGS_singleop.empty()) {
  1132. ret = GenerateSingleOp(FLAGS_singleop);
  1133. break;
  1134. }
  1135. // default mode(mode:0), Open source model to model
  1136. if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) {
  1137. GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break);
  1138. } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON
  1139. GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED;
  1140. break, "ATC ConvertJson execute failed!!");
  1141. } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) {
  1142. GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED;
  1143. break, "ATC convert pbtxt to json execute failed!!");
  1144. } else {
  1145. ErrorManager::GetInstance().ATCReportErrMessage(
  1146. "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport});
  1147. GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport);
  1148. ret = domi::FAILED;
  1149. break;
  1150. }
  1151. } while (0);
  1152. if (!CheckRet(ret)) {
  1153. std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl;
  1154. int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO);
  1155. if (result != 0) {
  1156. DOMI_LOGE("ErrorManager outputErrMessage fail !");
  1157. }
  1158. GELOGI("Current mem available mem is [%lu]", GetMemInfo("MemAvailable"));
  1159. return ret;
  1160. } else {
  1161. std::cout << "ATC run success, welcome to the next use." << std::endl;
  1162. (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO);
  1163. return 0;
  1164. }
  1165. } /*lint +e530*/

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。