You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cc 62 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <dlfcn.h>
  18. #include <gflags/gflags.h>
  19. #include <sys/types.h>
  20. #include <unistd.h>
  21. #include <cctype>
  22. #include <climits>
  23. #include <cstdlib>
  24. #include <iostream>
  25. #include "common/gflags_util.h"
  26. #include "common/util.h"
  27. #include "common/util/error_manager/error_manager.h"
  28. #include "framework/common/debug/ge_log.h"
  29. #include "ge/ge_api.h"
  30. #include "generator/ge_generator.h"
  31. #include "graph/anchor.h"
  32. #include "graph/debug/ge_attr_define.h"
  33. #include "graph/graph.h"
  34. #include "graph/op_desc.h"
  35. #include "graph/utils/graph_utils.h"
  36. #include "graph/utils/type_utils.h"
  37. #include "init/gelib.h"
  38. #include "ir_build/atc_ir_common.h"
  39. #include "omg/omg.h"
  40. #include "omg/parser/parser_factory.h"
  41. #include "omg/parser/parser_inner_ctx.h"
  42. #include "parser/common/register_tbe.h"
  43. #include "register/op_registry.h"
  44. #include "single_op_parser.h"
  45. #include "keep_dtype_option.h"
  46. using domi::BuildMode;
  47. using domi::OpRegistrationData;
  48. using domi::OpRegistry;
  49. using domi::Status;
  50. using domi::SUCCESS;
  51. using ge::GEN_OM_MODEL;
  52. using ge::GflagsUtils;
  53. using ge::MODEL_TO_JSON;
  54. using ge::ONLY_PRE_CHECK;
  55. using ge::ParseInputShape;
  56. using ge::PBTXT_TO_JSON;
  57. using std::map;
  58. using std::pair;
  59. using std::shared_ptr;
  60. using std::string;
  61. using std::vector;
  62. static bool is_dynamic_input = false;
  63. const char *const kModeSupport = "only support 0(model to framework model), "
  64. "1(framework model to json), 3(only pre-check), "
  65. "5(pbtxt to json), 6(display model info)";
  66. const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)";
  67. static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
  68. static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
  69. static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model";
  70. // limit available mem size 2G
  71. const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024
  72. DEFINE_string(model, "", "The model file.");
  73. DEFINE_string(output, "", "The output file path&name.");
  74. DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx).");
  75. DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe.");
  76. DEFINE_string(input_shape, "",
  77. "Optional; shape of input data. Required when framework is caffe "
  78. "or TensorFLow or MindSpore or Onnx. "
  79. "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"");
  80. DEFINE_bool(h, false, "show this help message");
  81. DEFINE_string(cal_conf, "", "Optional; the calibration config file.");
  82. DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op.");
  83. DEFINE_string(op_name_map, "", "Optional; custom op name mapping file.");
  84. DEFINE_string(target, "", "Optional; mini.");
  85. DEFINE_string(om, "", "The model file to be converted to json.");
  86. DEFINE_string(json, "", "The output json file path&name which is converted from a model.");
  87. DEFINE_int32(mode, 0,
  88. "Optional; run mode, 0(default): model => framework model; 1: "
  89. "framework model => json; 3: only pre-check; 5: txt => json.");
  90. #if !defined(__ANDROID__) && !defined(ANDROID)
  91. DEFINE_int32(encrypt_mode, -1, "Optional; the encrypt flag. 0: encrypt; -1(default): not encrypt");
  92. DEFINE_string(encrypt_key, "", "Optional; the encrypt_key file.");
  93. DEFINE_string(certificate, "", "Optional; the certificate file.");
  94. DEFINE_string(hardware_key, "", "Optional; the ISV key file.");
  95. DEFINE_string(private_key, "", "Optional; the private key file.");
  96. #endif
  97. DEFINE_string(out_nodes, "",
  98. "Optional; output nodes designated by users."
  99. "Format: \"node_name1:0;node_name1:1;node_name2:0\"");
  100. DEFINE_string(precision_mode, "force_fp16",
  101. "Optional; precision mode."
  102. "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype.");
  103. DEFINE_string(keep_dtype, "",
  104. "Optional; config file to specify the precision used by the operator during compilation.");
  105. DEFINE_string(input_format, "",
  106. "Optional; input_format, format of input data, NCHW;NHWC."
  107. "Format:\"NHWC\"");
  108. DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file.");
  109. DEFINE_string(input_fp16_nodes, "",
  110. "Optional; input node datatype is fp16 and format is NC1HWC0."
  111. "Format:\"node_name1;node_name2\"");
  112. DEFINE_string(is_output_adjust_hw_layout, "",
  113. "Optional; Net output node's datatype is fp16 and format is "
  114. "NC1HWC0, or not."
  115. "Format:\"false,true,false,true\"");
  116. DEFINE_string(is_input_adjust_hw_layout, "",
  117. "Optional; Intput node's datatype is fp16 and format is "
  118. "NC1HWC0, or not."
  119. "Format:\"false,true,false,true\"");
  120. DEFINE_string(output_type, "",
  121. "Optional; output type! "
  122. "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE.");
  123. DEFINE_string(op_select_implmode, "",
  124. "Optional; op select implmode! "
  125. "Support high_precision, high_performance.");
  126. DEFINE_string(optypelist_for_implmode, "",
  127. "Optional; Nodes need use implmode selected in op_select_implmode "
  128. "Format:\"node_name1,node_name2\"");
  129. DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file.");
  130. DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if.");
  131. DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode.");
  132. DEFINE_string(soc_version, "", "The soc version.");
  133. DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core.");
  134. DEFINE_string(aicore_num, "", "Optional; Set aicore num");
  135. DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize");
  136. DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path");
  137. DEFINE_string(save_original_model, "", "Optional; enable output original offline model. false(default)");
  138. DEFINE_string(dynamic_batch_size, "",
  139. "Optional; If set, generate dynamic multi batch model. "
  140. "Different batch sizes are split by ','."
  141. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  142. DEFINE_string(dynamic_image_size, "",
  143. "Optional; If set, generate dynamic multi image size model."
  144. "Different groups of image size are split by ';',"
  145. "while different dimensions of each group are split by ','."
  146. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  147. DEFINE_string(dynamic_dims, "",
  148. "Optional; If set, generate dynamic input size model. "
  149. "Different groups of size are split by ';', while different dimensions of each group are split by ','."
  150. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  151. DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled.");
  152. DEFINE_string(enable_compress_weight, "false",
  153. "Optional; enable compress weight. true: enable; false(default): disable");
  154. DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight");
  155. DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable");
  156. DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null");
  157. DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0.");
  158. DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;"
  159. "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler");
  160. DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass,"
  161. "multiple names can be set and separated by ','.");
  162. DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation");
  163. DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator compilation files");
  164. DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode");
  165. DEFINE_string(mdl_bank_path, "", "Optional; model bank path");
  166. DEFINE_string(op_bank_path, "", "Optional; op bank path");
  167. DEFINE_string(display_model_info, "0", "Optional; display model info");
  168. class GFlagUtils {
  169. public:
  170. /**
  171. * @name InitGFlag
  172. * @brief initialize gflag
  173. * @return void
  174. */
  175. static void InitGFlag(int argc, char *argv[]) {
  176. // -help
  177. gflags::SetUsageMessage(
  178. "usage: ./atc <args>\n"
  179. "generate offline model example:\n"
  180. "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n"
  181. "--framework=0 --output=./domi \n"
  182. "generate offline model for single op example:\n"
  183. "./atc --singleop=./op_list.json --output=./op_model \n"
  184. "===== Basic Functionality =====\n"
  185. "[General]\n"
  186. " --h/help Show this help message\n"
  187. " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; "
  188. "3: only pre-check; 5: convert ge dump txt file to JSON format; 6: display model info\n"
  189. "\n[Input]\n"
  190. " --model Model file\n"
  191. " --weight Weight file. Required when framework is Caffe\n"
  192. " --om The model file to be converted to json\n"
  193. " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n"
  194. " --input_format Format of input data. E.g.: \"NCHW\"\n"
  195. " --input_shape Shape of input data. Separate multiple nodes with semicolons (;). "
  196. "Use double quotation marks (\") to enclose each argument.\n"
  197. " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n"
  198. " --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n"
  199. " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). "
  200. "Use double quotation marks (\") to enclose each argument.\n"
  201. " E.g.: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n"
  202. " --dynamic_dims Set dynamic dims. Separate multiple nodes with semicolons (;). "
  203. "Use double quotation marks (\") to enclose each argument.\n"
  204. " E.g.: \"dims1_n1,dims1_n2;dims2_n1,dims2_n2\"\n"
  205. " --singleop Single op definition file. atc will generate offline "
  206. "model(s) for single op if --singleop is set.\n"
  207. "\n[Output]\n"
  208. " --output Output file path&name(needn't suffix, will add .om automatically). \n"
  209. " If --singleop is set, this arg specifies the directory to "
  210. "which the single op offline model will be generated\n"
  211. " --output_type Set net output type. Support FP32, FP16, UINT8. "
  212. "E.g.: FP16, indicates that all out nodes are set to FP16.\n"
  213. " \"node1:0:FP16;node2:1:FP32\", indicates setting the datatype of multiple out nodes.\n"
  214. " --check_report The pre-checking report file. Default value is: \"check_result.json\"\n"
  215. " --json The output json file path&name which is converted from a model\n"
  216. "\n[Target]\n"
  217. " --soc_version The soc version.\n"
  218. " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. "
  219. "Default value is: AiCore\n"
  220. " --aicore_num Set aicore num\n"
  221. "===== Advanced Functionality =====\n"
  222. "[Feature]\n"
  223. " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)."
  224. "Use double quotation marks (\") to enclose each argument.\n"
  225. " E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n"
  226. " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons (;). "
  227. "Use double quotation marks (\") to enclose each argument. "
  228. "E.g.: \"node_name1;node_name2\"\n"
  229. " --insert_op_conf Config file to insert new op\n"
  230. " --op_name_map Custom op name mapping file\n"
  231. " Note: A semicolon(;) cannot be included in each "
  232. "path, otherwise the resolved path will not match the expected one.\n"
  233. " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is "
  234. "NC1HWC0, used with input_fp16_nodes. E.g.: \"true,true,false,true\"\n"
  235. " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is "
  236. "NC1HWC0, used with out_nodes. E.g.: \"true,true,false,true\"\n"
  237. "\n[Model Tuning]\n"
  238. " --disable_reuse_memory The switch of reuse memory. Default value is : 0. "
  239. "0 means reuse memory, 1 means do not reuse memory.\n"
  240. " --fusion_switch_file Set fusion switch file path\n"
  241. " --enable_scope_fusion_passes validate the non-general scope fusion passes, "
  242. "multiple names can be set and separated by ','. E.g.: ScopePass1,ScopePass2,...\n"
  243. " --enable_single_stream Enable single stream. true: enable; false(default): disable\n"
  244. " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n"
  245. " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n"
  246. " --compress_weight_conf Config file to compress weight\n"
  247. " --buffer_optimize Set buffer optimize. Support \"l2_optimize\" (default), "
  248. "\"l1_optimize\", \"off_optimize\"\n"
  249. " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n"
  250. "\n[Operator Tuning]\n"
  251. " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, "
  252. "allow_fp32_to_fp16, must_keep_origin_dtype.\n"
  253. " --keep_dtype Retains the precision of certain operators in inference "
  254. "scenarios by using a configuration file.\n"
  255. " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
  256. " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n"
  257. " --op_select_implmode Set op select implmode. Support high_precision, high_performance. "
  258. "default: high_performance\n"
  259. " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n"
  260. " Separate multiple nodes with commas (,). Use double quotation marks (\") "
  261. "to enclose each argument. E.g.: \"node_name1,node_name2\"\n"
  262. " --op_debug_level Debug enable for TBE operator building.\n"
  263. " 0 (default): Disable debug; 1: Enable TBE pipe_all, "
  264. "and generate the operator CCE file and Python-CCE mapping file (.json);\n"
  265. " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file "
  266. "(.json), and enable the CCE compiler -O0-g.\n"
  267. " 3: Disable debug, and keep generating kernel file (.o and .json)\n"
  268. "\n[Debug]\n"
  269. " --save_original_model Control whether to output original model. E.g.: true: output original model\n"
  270. " --log Generate log with level. Support debug, info, warning, error, null\n"
  271. " --dump_mode The switch of dump json with shape, to be used with mode 1. "
  272. "0(default): disable; 1: enable.\n"
  273. " --debug_dir Set the save path of operator compilation intermediate files.\n"
  274. "Default value: ./kernel_meta\n"
  275. " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n"
  276. "Default value: $HOME/atc_data\n"
  277. " --op_compiler_cache_mode Set the operator compilation cache mode."
  278. "Options are disable(default), enable and force(force to refresh the cache)\n"
  279. " --display_model_info enable for display model info; 0(default): close display, 1: open display");
  280. gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
  281. // Using gflags to analyze input parameters
  282. GflagsUtils::ChangeHelpFlags(FLAGS_h);
  283. gflags::HandleCommandLineHelpFlags();
  284. }
  285. static Status CheckDumpInfershapeJsonFlags() {
  286. Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight);
  287. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED,
  288. "check custom aicpu run so failed!");
  289. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  290. FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"),
  291. return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!",
  292. FLAGS_weight.c_str());
  293. return domi::SUCCESS;
  294. }
  295. static Status CheckFlags() {
  296. Status ret = ge::SUCCESS;
  297. // No model file information passed in
  298. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  299. FLAGS_model == "",
  300. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"});
  301. ret = ge::FAILED, "Input parameter[--model]'s value is empty!");
  302. // check param disable_reuse_memory
  303. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  304. ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) != ge::SUCCESS,
  305. ret = ge::FAILED, "check disable_reuse_memory failed!");
  306. // check optypelist_for_implmode and op_select_implmode
  307. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  308. ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode,
  309. FLAGS_op_select_implmode) != ge::SUCCESS,
  310. ret = ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!");
  311. // No output file information passed in
  312. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  313. FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "",
  314. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"});
  315. ret = ge::FAILED, "Input parameter[--output]'s value is empty!");
  316. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  317. CheckFrameWorkValid(FLAGS_framework, FLAGS_weight) != ge::SUCCESS,
  318. ret = ge::FAILED,
  319. "CheckFrameWorkValid failed");
  320. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  321. ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size,
  322. FLAGS_dynamic_dims, FLAGS_input_shape,
  323. FLAGS_input_format, is_dynamic_input) != ge::SUCCESS,
  324. ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!");
  325. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  326. !FLAGS_insert_op_conf.empty() && !FLAGS_dynamic_dims.empty(),
  327. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  328. {"parameter", "value", "reason"},
  329. {"--insert_op_conf", FLAGS_insert_op_conf,
  330. "dynamic dims function does not support aipp"});
  331. ret = ge::FAILED, "dynamic dims function does not support aipp");
  332. #if !defined(__ANDROID__) && !defined(ANDROID)
  333. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), ret = ge::FAILED,
  334. "encrypt_mode %d not valid!!", FLAGS_encrypt_mode);
  335. if (FLAGS_encrypt_mode == 0) { // Encryption mode
  336. GELOGI("ge will run with encrypt!");
  337. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), ret = ge::FAILED,
  338. "encrypt_key file not found!!");
  339. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), ret = ge::FAILED,
  340. "certificate file not found!!");
  341. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), ret = ge::FAILED,
  342. "hardware_key file not found!!");
  343. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), ret = ge::FAILED,
  344. "private_key file not found!!");
  345. } else { // No encryption
  346. GELOGI("ge will run without encrypt!");
  347. }
  348. #endif
  349. /**
  350. * Check the validity of the I / O file path
  351. */
  352. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  353. FLAGS_model != "" && !ge::CheckInputPathValid(FLAGS_model, "--model"), ret = ge::FAILED,
  354. "model file %s not found!!", FLAGS_model.c_str());
  355. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  356. FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"),
  357. ret = ge::FAILED, "weight file %s not found!!",
  358. FLAGS_weight.c_str());
  359. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  360. FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"),
  361. ret = ge::FAILED, "calibration config file %s not found!!",
  362. FLAGS_cal_conf.c_str());
  363. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  364. FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"),
  365. ret = ge::FAILED, "op config file %s not found!!",
  366. FLAGS_op_name_map.c_str());
  367. GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS,
  368. ret = ge::FAILED, "check insert op conf failed!");
  369. GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid(
  370. FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS,
  371. ret = ge::FAILED, "check compress weight failed!");
  372. GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS,
  373. ret = ge::FAILED, "check keep dtype failed!");
  374. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  375. !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED,
  376. "check_report file %s not found!!", FLAGS_check_report.c_str());
  377. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  378. FLAGS_mode == GEN_OM_MODEL && FLAGS_output != "" &&
  379. (!ge::CheckOutputPathValid(FLAGS_output, "--output") || !CheckPathWithName(FLAGS_output)),
  380. ret = ge::FAILED, "output path %s is not valid!!", FLAGS_output.c_str());
  381. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  382. FLAGS_save_original_model != "" &&
  383. FLAGS_save_original_model != "true" &&
  384. FLAGS_save_original_model != "false",
  385. ErrorManager::GetInstance().ATCReportErrMessage(
  386. "E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model});
  387. ret = ge::FAILED,
  388. "Input parameter[--save_original_model]'s value[%s] must be true or false.",
  389. FLAGS_save_original_model.c_str());
  390. GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS,
  391. ret = ge::FAILED, "check output type failed!");
  392. GE_CHK_BOOL_EXEC(
  393. ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS,
  394. ret = ge::FAILED, "check enable single stream failed!");
  395. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((FLAGS_display_model_info != "0") && (FLAGS_display_model_info != "1"),
  396. ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"display_model_info"});
  397. ret = ge::FAILED, "Input parameter[--display_model_info]'s value must be 1 or 0.");
  398. return ret;
  399. }
  400. /**
  401. * Verifying the parameters of converting model to JSON
  402. * 1. Fmk_model
  403. * 2. out_json
  404. **/
  405. static Status CheckConverJsonParamFlags() {
  406. Status ret = ge::SUCCESS;
  407. // No model path passed in
  408. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
  409. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
  410. ret = ge::FAILED,
  411. "Input parameter[--om]'s value is empty!!");
  412. // JSON path not passed in
  413. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "",
  414. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"});
  415. ret = ge::FAILED,
  416. "Input parameter[--json]'s value is empty!!");
  417. // Check if the model path is valid
  418. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  419. FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
  420. ret = ge::FAILED,
  421. "model file path is invalid: %s.", FLAGS_om.c_str());
  422. // Check whether the JSON path is valid
  423. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  424. FLAGS_json != "" && !ge::CheckOutputPathValid(FLAGS_json, "--json"),
  425. ret = ge::FAILED,
  426. "json file path is invalid: %s.", FLAGS_json.c_str());
  427. return ret;
  428. }
  429. /**
  430. * Check command line parameters for explicit settings
  431. * true: Explicit setup
  432. * false: Not set up
  433. * */
  434. static bool CheckFlagSet(string flag) {
  435. gflags::CommandLineFlagInfo info;
  436. return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default);
  437. }
  438. private:
  439. static bool CheckEncryptModeValid(const int encrypt_mode) {
  440. #if !defined(__ANDROID__) && !defined(ANDROID)
  441. if (encrypt_mode != 0 && encrypt_mode != -1) {
  442. DOMI_LOGE("encrypt mode must be 0 or -1");
  443. return false;
  444. }
  445. #else
  446. if (encrypt_mode != -1) {
  447. DOMI_LOGE("encrypt mode must be -1");
  448. return false;
  449. }
  450. #endif
  451. return true;
  452. }
  453. static Status CheckFrameWorkValid(int framework, const std::string weight_file) {
  454. if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW &&
  455. framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) {
  456. // No framework information was passed in or the entered framework is illegal
  457. ErrorManager::GetInstance().ATCReportErrMessage(
  458. "E10007", {"parameter", "support"},
  459. {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)"});
  460. DOMI_LOGE("Input parameter[--framework] is mandatory and it's value must be: "
  461. "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx).");
  462. return domi::PARAM_INVALID;
  463. }
  464. if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) {
  465. ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"});
  466. DOMI_LOGE("Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!");
  467. return domi::PARAM_INVALID;
  468. }
  469. if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) {
  470. GELOGW("Parameter weight is ignored for TensorFlow.");
  471. }
  472. if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) {
  473. GELOGW("Parameter weight is ignored for Onnx.");
  474. }
  475. return domi::SUCCESS;
  476. }
  477. static bool CheckPathWithName(const std::string &fileName) {
  478. // Determine file path length
  479. if (fileName.size() > static_cast<int>(PATH_MAX)) {
  480. ErrorManager::GetInstance().ATCReportErrMessage(
  481. "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)});
  482. GELOGE(ge::FAILED, "Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX);
  483. return false;
  484. }
  485. // Find the last separator
  486. int slashPosition = fileName.size() - 1;
  487. for (; slashPosition >= 0; slashPosition--) {
  488. if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') {
  489. break;
  490. }
  491. }
  492. // Failure if no filename follows the path
  493. if (slashPosition == static_cast<int>(fileName.size() - 1)) {
  494. ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName});
  495. DOMI_LOGE("Input parameter[--output]'s path[%s] not include file name", fileName.c_str());
  496. return false;
  497. }
  498. return true;
  499. }
  500. };
  501. void SetDynamicInputSizeOptions() {
  502. if (!FLAGS_dynamic_batch_size.empty()) {
  503. domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size;
  504. }
  505. if (!FLAGS_dynamic_image_size.empty()) {
  506. domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size;
  507. }
  508. if (!FLAGS_dynamic_dims.empty()) {
  509. domi::GetContext().dynamic_dims = FLAGS_dynamic_dims;
  510. }
  511. }
  512. /// Validate the non-general scope fusion pass.
  513. /// The parameter is set to the name of the fusion rule.
  514. /// Multiple names can be set and separated by ",".
  515. void SetEnableScopeFusionPasses(const std::string pass_names) {
  516. ge::GetParserContext().enable_scope_fusion_passes = pass_names;
  517. }
  518. static bool CheckInputFormat() {
  519. if (FLAGS_input_format.empty()) {
  520. // Set default format
  521. if (FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW)) {
  522. FLAGS_input_format = "NHWC";
  523. } else {
  524. FLAGS_input_format = "NCHW";
  525. }
  526. return true;
  527. } else if ((FLAGS_framework == static_cast<int32_t>(domi::CAFFE))) { // caffe
  528. if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) {
  529. return true;
  530. }
  531. // only support NCHW ND
  532. ErrorManager::GetInstance().ATCReportErrMessage(
  533. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport});
  534. GELOGE(ge::FAILED,
  535. "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kCaffeFormatSupport);
  536. return false;
  537. } else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf
  538. if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) {
  539. return true;
  540. }
  541. // only support NCHW NHWC ND NCDHW NDHWC
  542. ErrorManager::GetInstance().ATCReportErrMessage(
  543. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport});
  544. GELOGE(ge::FAILED,
  545. "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kTFFormatSupport);
  546. return false;
  547. } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
  548. if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) {
  549. return true;
  550. }
  551. // only support NCHW ND
  552. ErrorManager::GetInstance().ATCReportErrMessage(
  553. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport});
  554. GELOGE(ge::FAILED,
  555. "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), kONNXFormatSupport);
  556. return false;
  557. }
  558. return true;
  559. }
  560. #if !defined(__ANDROID__) && !defined(ANDROID)
  561. static void GetCustomOpPath(std::string &customop_path) {
  562. GELOGI("Enter get custom op path schedule");
  563. std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast<domi::FrameworkType>(FLAGS_framework));
  564. GELOGI("Framework type is %s.", fmk_type.c_str());
  565. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  566. if (path_env != nullptr) {
  567. std::string path = path_env;
  568. customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
  569. GELOGI("Get custom so path from env : %s", path_env);
  570. return;
  571. }
  572. std::string path_base = ge::GELib::GetPath();
  573. GELOGI("path_base is %s", path_base.c_str());
  574. path_base = path_base.substr(0, path_base.rfind('/'));
  575. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  576. customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type);
  577. return;
  578. }
  579. void GetPluginSoFileList(const string &path, vector<string> &fileList, string &caffe_parser_path) {
  580. // Support to split multiple so directories by ":"
  581. GELOGI("path is %s", path.c_str());
  582. vector<string> v_path = ge::StringUtils::Split(path, ':');
  583. for (size_t i = 0; i < v_path.size(); ++i) {
  584. ge::FindParserSo(v_path[i], fileList, caffe_parser_path);
  585. GELOGI("CustomOpLib full name = %s", v_path[i].c_str());
  586. }
  587. }
  588. void LoadModelParserLib(std::string caffe_parser_path) {
  589. if (FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW)) {
  590. void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL);
  591. if (tf_handle == nullptr) {
  592. GELOGW("dlopen fmk library [libfmk_parser.so] failed.");
  593. return;
  594. }
  595. GELOGI("plugin load libfmk_parser.so success.");
  596. } else if (FLAGS_framework == static_cast<int32_t>(domi::CAFFE)) {
  597. // What we are dealing with here is that the user modifies the caffe.proto scenario.
  598. // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path.
  599. caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path;
  600. void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
  601. if (handle == nullptr) {
  602. GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror());
  603. return;
  604. }
  605. GELOGI("plugin load %s success.", caffe_parser_path.c_str());
  606. // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so).
  607. // (depend on the lib_caffe_parser.so)
  608. void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL);
  609. if (fmk_handle == nullptr) {
  610. GELOGW("dlopen fmk library [libfmk_parser.so] failed.");
  611. if (dlclose(handle) != 0) {
  612. GELOGW("dlclose lib_caffe_parser.so failed.");
  613. }
  614. return;
  615. }
  616. GELOGI("plugin load libfmk_parser.so success.");
  617. } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
  618. void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL);
  619. if (handle == nullptr) {
  620. GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed.");
  621. return;
  622. }
  623. GELOGI("plugin load libfmk_onnx_parser.so success.");
  624. } else {
  625. GELOGW("Framework:%s is not support.",
  626. ge::TypeUtils::FmkTypeToSerialString(static_cast<domi::FrameworkType>(FLAGS_framework)).c_str());
  627. return;
  628. }
  629. return;
  630. }
  631. void LoadCustomOpLib(bool need_load_ops_plugin) {
  632. std::string plugin_path;
  633. GetCustomOpPath(plugin_path);
  634. vector<string> fileList;
  635. string caffe_parser_path = "";
  636. // whether there are files in the plugin so path
  637. GetPluginSoFileList(plugin_path, fileList, caffe_parser_path);
  638. // no file
  639. if (fileList.empty() && caffe_parser_path.empty()) {
  640. GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str());
  641. }
  642. LoadModelParserLib(caffe_parser_path);
  643. if (!need_load_ops_plugin) {
  644. GELOGI("No need to load ops plugin so.");
  645. return;
  646. }
  647. OpRegistry::Instance()->registrationDatas.clear();
  648. // load other so files except lib_caffe_parser.so in the plugin so path
  649. for (auto elem : fileList) {
  650. ge::StringUtils::Trim(elem);
  651. void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL);
  652. if (handle == nullptr) {
  653. GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror());
  654. } else {
  655. GELOGI("plugin load %s success.", elem.c_str());
  656. }
  657. }
  658. std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
  659. for (OpRegistrationData reg_data : registrationDatas) {
  660. if (reg_data.GetFrameworkType() == static_cast<domi::FrameworkType>(FLAGS_framework)) {
  661. (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data);
  662. (void)OpRegistry::Instance()->Register(reg_data);
  663. }
  664. }
  665. }
  666. void SaveCustomCaffeProtoPath() {
  667. GELOGI("Enter save custom caffe proto path.");
  668. std::string path_base = ge::GELib::GetPath();
  669. GELOGI("path_base is %s", path_base.c_str());
  670. path_base = path_base.substr(0, path_base.rfind('/'));
  671. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  672. ge::GetParserContext().caffe_proto_path = path_base + "include/proto/";
  673. string customop_path;
  674. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  675. if (path_env != nullptr) {
  676. std::string path = path_env;
  677. customop_path = path + "/framework/custom/caffe/";
  678. GELOGI("Get custom proto path from env : %s", path_env);
  679. ge::GetParserContext().custom_proto_path = customop_path;
  680. return;
  681. }
  682. customop_path = path_base + "ops/framework/custom/caffe/";
  683. ge::GetParserContext().custom_proto_path = customop_path;
  684. return;
  685. }
  686. #endif
  687. Status CreateInputsForInference(const ge::Graph &graph, vector<ge::GeTensor> &inputs) {
  688. auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  689. GE_CHECK_NOTNULL(compute_graph);
  690. for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) {
  691. GE_CHECK_NOTNULL(input_node);
  692. ge::OpDescPtr op = input_node->GetOpDesc();
  693. GE_CHECK_NOTNULL(op);
  694. if (op->GetType() == ge::DATA) {
  695. GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size());
  696. ge::GeTensorDesc tensor = op->GetInputDesc(0);
  697. string data_op_name = op->GetName();
  698. GELOGI("Data op name is: %s", data_op_name.c_str());
  699. ge::GeShape data_shape;
  700. auto iter = domi::GetContext().input_dims.find(data_op_name);
  701. if (iter != domi::GetContext().input_dims.end()) {
  702. data_shape = ge::GeShape(iter->second);
  703. GELOGI("Data op get shape from Context.");
  704. } else {
  705. data_shape = tensor.GetShape();
  706. GELOGI("Data op get shape from InputDesc in geir graph.");
  707. }
  708. ge::DataType data_type = tensor.GetDataType();
  709. string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type);
  710. GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str());
  711. ge::GeTensor input_tensor;
  712. ge::GeTensorDesc desc(data_shape, ge::Format(domi::GetContext().format), data_type);
  713. input_tensor.SetTensorDesc(desc);
  714. inputs.push_back(input_tensor);
  715. }
  716. }
  717. GELOGI("Build ME model, inputs size is: %zu", inputs.size());
  718. return ge::SUCCESS;
  719. }
  720. domi::Status GenerateInfershapeJson() {
  721. if (!CheckInputFormat()) {
  722. GELOGE(ge::FAILED, "Check input_format failed");
  723. return domi::FAILED;
  724. }
  725. Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags();
  726. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!");
  727. ge::GeGenerator ge_generator;
  728. std::map<string, string> options;
  729. ge::Status geRet = ge_generator.Initialize(options, domi::GetContext());
  730. if (geRet != ge::SUCCESS) {
  731. DOMI_LOGE("GeGenerator initialize failed!");
  732. return domi::FAILED;
  733. }
  734. ge::Graph graph;
  735. std::map<string, string> atc_params;
  736. atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
  737. ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework,
  738. "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false);
  739. if (ret != ge::SUCCESS) {
  740. DOMI_LOGE("ATC Parse graph domi::FAILED");
  741. (void)ge_generator.Finalize();
  742. return domi::FAILED;
  743. }
  744. geRet = ge_generator.GenerateInfershapeGraph(graph);
  745. if (geRet != ge::SUCCESS) {
  746. DOMI_LOGE("ATC GenerateInfershapeJson failed");
  747. (void)ge_generator.Finalize();
  748. return domi::FAILED;
  749. }
  750. if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) {
  751. DOMI_LOGE("ATC DumpInfershapeJson failed");
  752. (void)ge_generator.Finalize();
  753. return domi::FAILED;
  754. }
  755. (void)ge_generator.Finalize();
  756. return ge::SUCCESS;
  757. }
  758. static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) {
  759. Status ret = ge::SUCCESS;
  760. if (fwk_type == -1) {
  761. ret = ge::ConvertOm(model_file.c_str(), json_file.c_str(), true);
  762. return ret;
  763. }
  764. if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) {
  765. ErrorManager::GetInstance().ATCReportErrMessage(
  766. "E10001", {"parameter", "value", "reason"},
  767. {"--framework", std::to_string(fwk_type), kModelToJsonSupport});
  768. GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport);
  769. ret = ge::FAILED;
  770. }
  771. if (FLAGS_dump_mode != "0" && FLAGS_dump_mode != "1") {
  772. ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"});
  773. GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0.");
  774. ret = ge::FAILED;
  775. }
  776. if (ret != ge::SUCCESS) return ret;
  777. // Need to save caffe.proto path
  778. SaveCustomCaffeProtoPath();
  779. if (FLAGS_dump_mode == "0") {
  780. // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so.
  781. LoadCustomOpLib(false);
  782. ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str());
  783. } else if (FLAGS_dump_mode == "1") {
  784. // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so.
  785. LoadCustomOpLib(true);
  786. ret = GenerateInfershapeJson();
  787. }
  788. return ret;
  789. }
  790. domi::Status GenerateModel(std::map<string, string> &options, std::string output) {
  791. ge::GeGenerator ge_generator;
  792. ge::Status geRet = ge::SUCCESS;
  793. std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance();
  794. if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
  795. geRet = ge::GELib::Initialize(options);
  796. if (geRet != ge::SUCCESS) {
  797. DOMI_LOGE("GE initialize failed!");
  798. return domi::FAILED;
  799. }
  800. }
  801. geRet = ge_generator.Initialize(options, domi::GetContext());
  802. if (geRet != ge::SUCCESS) {
  803. DOMI_LOGE("GeGenerator initialize failed!");
  804. (void)ge::GELib::GetInstance()->Finalize();
  805. return domi::FAILED;
  806. }
  807. ge::Graph graph;
  808. std::vector<ge::GeTensor> inputs;
  809. if (FLAGS_framework == domi::MINDSPORE) {
  810. // load model from file
  811. ge::Model load_model = ge::Model("loadmodel", "version2");
  812. auto ret1 = load_model.LoadFromFile(FLAGS_model);
  813. if (ret1 != ge::GRAPH_SUCCESS) {
  814. ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model});
  815. DOMI_LOGE("Load model from %s failed, please check model file or "
  816. "input parameter[--framework] is correct", FLAGS_model.c_str());
  817. (void)ge_generator.Finalize();
  818. (void)ge::GELib::GetInstance()->Finalize();
  819. return domi::FAILED;
  820. }
  821. graph = load_model.GetGraph();
  822. GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input),
  823. GELOGE(ge::FAILED, "ATC Generate call InitDomiOmgContext ret fail");
  824. (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED);
  825. Status ret = CreateInputsForInference(graph, inputs);
  826. if (ret != ge::SUCCESS) {
  827. GELOGE(ge::FAILED, "create inputs for inference failed.");
  828. (void)ge_generator.Finalize();
  829. (void)ge::GELib::GetInstance()->Finalize();
  830. return domi::FAILED;
  831. }
  832. } else {
  833. std::map<string, string> atc_params;
  834. atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape));
  835. atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes));
  836. atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
  837. atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report));
  838. atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes));
  839. atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout));
  840. atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout));
  841. atc_params.insert(std::pair<string, string>("compress_weight_conf", FLAGS_compress_weight_conf));
  842. atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
  843. atc_params.insert(std::pair<string, string>("output", output));
  844. Status ret =
  845. ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework,
  846. FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input);
  847. // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph
  848. if (FLAGS_mode == ge::ONLY_PRE_CHECK) {
  849. (void)ge_generator.Finalize();
  850. (void)ge::GELib::GetInstance()->Finalize();
  851. if (ret != ge::SUCCESS) {
  852. DOMI_LOGE("ATC precheck fail.");
  853. return domi::FAILED;
  854. }
  855. return domi::SUCCESS;
  856. }
  857. if (ret != ge::SUCCESS) {
  858. DOMI_LOGE("ATC Parse graph domi::FAILED");
  859. DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case
  860. (void)ge_generator.Finalize();
  861. (void)ge::GELib::GetInstance()->Finalize();
  862. return domi::FAILED;
  863. }
  864. if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) {
  865. DOMI_LOGE("Set output node info fail.");
  866. (void)ge_generator.Finalize();
  867. (void)ge::GELib::GetInstance()->Finalize();
  868. return domi::FAILED;
  869. }
  870. }
  871. Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype);
  872. if (ret != SUCCESS) {
  873. (void)ge_generator.Finalize();
  874. (void)ge::GELib::GetInstance()->Finalize();
  875. return ret;
  876. }
  877. geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);
  878. if (geRet != ge::SUCCESS) {
  879. DOMI_LOGE("GE GenerateOfflineModel execute failed");
  880. DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case
  881. // checking error log)
  882. (void)ge_generator.Finalize();
  883. (void)ge::GELib::GetInstance()->Finalize();
  884. return domi::FAILED;
  885. }
  886. (void)ge_generator.Finalize();
  887. (void)ge::GELib::GetInstance()->Finalize();
  888. return ge::SUCCESS;
  889. }
  890. static void SetEnvForSingleOp(std::map<string, string> &options) {
  891. string flag_on = "1";
  892. string flag_off = "0";
  893. options.emplace(ge::GE_FE_FLAG, flag_on);
  894. options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream
  895. options.emplace(ge::RUN_FLAG, flag_off);
  896. options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off);
  897. options.emplace(ge::SINGLE_OP_FLAG, flag_on);
  898. options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode);
  899. options.emplace(ge::SOC_VERSION, FLAGS_soc_version);
  900. options.emplace(ge::CORE_TYPE, FLAGS_core_type);
  901. options.emplace(ge::AICORE_NUM, FLAGS_aicore_num);
  902. options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode);
  903. options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode);
  904. options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode);
  905. options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level));
  906. options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir);
  907. options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir);
  908. options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode);
  909. options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path);
  910. options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path);
  911. }
  912. domi::Status GenerateSingleOp(const std::string& json_file_path) {
  913. if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) {
  914. DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str());
  915. return domi::FAILED;
  916. }
  917. // check optypelist_for_implmode and op_select_implmode
  918. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  919. ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS,
  920. return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!");
  921. std::map<string, string> options;
  922. // need to be changed when ge.ini plan is done
  923. SetEnvForSingleOp(options);
  924. auto ret = ge::GELib::Initialize(options);
  925. if (ret != ge::SUCCESS) {
  926. DOMI_LOGE("GE initialize failed!");
  927. return domi::FAILED;
  928. }
  929. ge::GeGenerator generator;
  930. ret = generator.Initialize(options, domi::GetContext());
  931. if (ret != SUCCESS) {
  932. DOMI_LOGE("GeGenerator initialize failed!");
  933. (void)ge::GELib::GetInstance()->Finalize();
  934. return domi::FAILED;
  935. }
  936. vector<ge::SingleOpBuildParam> build_params;
  937. if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) {
  938. DOMI_LOGE("parse single op json file failed");
  939. (void)generator.Finalize();
  940. (void)ge::GELib::GetInstance()->Finalize();
  941. return domi::FAILED;
  942. }
  943. int index = 0;
  944. for (auto &param : build_params) {
  945. string output_path;
  946. if (!FLAGS_output.empty()) {
  947. output_path = FLAGS_output + "/";
  948. }
  949. output_path += param.file_name;
  950. ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path);
  951. if (ret != SUCCESS) {
  952. DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index);
  953. ret = domi::FAILED;
  954. break;
  955. }
  956. GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str());
  957. index += 1;
  958. }
  959. (void)generator.Finalize();
  960. (void)ge::GELib::GetInstance()->Finalize();
  961. return ret;
  962. }
  963. domi::Status GenerateOmModel() {
  964. if (!CheckInputFormat()) {
  965. GELOGE(ge::FAILED, "Check input_format failed");
  966. return domi::FAILED;
  967. }
  968. Status ret = GFlagUtils::CheckFlags();
  969. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED,
  970. "Check flags failed! Please check whether some atc params that include semicolons[;] use double "
  971. "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size");
  972. #if !defined(__ANDROID__) && !defined(ANDROID)
  973. // Load custom operator Library
  974. LoadCustomOpLib(true);
  975. SaveCustomCaffeProtoPath();
  976. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!");
  977. #endif
  978. const int f_stream_num = 1;
  979. std::map<string, string> options;
  980. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework)));
  981. options.insert(std::pair<string, string>(string(ge::STREAM_NUM), to_string(f_stream_num)));
  982. options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf));
  983. options.insert(std::pair<string, string>(string(ge::ENCRYPT_MODE), to_string(FLAGS_encrypt_mode)));
  984. options.insert(std::pair<string, string>(string(ge::EK_FILE), FLAGS_encrypt_key));
  985. options.insert(std::pair<string, string>(string(ge::CERT_FILE), FLAGS_certificate));
  986. options.insert(std::pair<string, string>(string(ge::HW_KEY_FILE), FLAGS_hardware_key));
  987. options.insert(std::pair<string, string>(string(ge::PRIVATE_KEY_FILE), FLAGS_private_key));
  988. options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes));
  989. options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf));
  990. options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode));
  991. options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
  992. options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));
  993. if (!FLAGS_output_type.empty()) {
  994. options.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
  995. }
  996. options.insert(std::pair<string, string>(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode));
  997. options.insert(std::pair<string, string>(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode));
  998. if (!FLAGS_input_fp16_nodes.empty()) {
  999. GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str());
  1000. options.insert(std::pair<string, string>(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes));
  1001. }
  1002. options.insert(std::pair<string, string>(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode));
  1003. options.insert(
  1004. std::pair<string, string>(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory)));
  1005. options.insert(std::pair<string, string>(string(ge::SOC_VERSION), FLAGS_soc_version));
  1006. options.insert(std::pair<string, string>(string(ge::CORE_TYPE), FLAGS_core_type));
  1007. options.insert(std::pair<string, string>(string(ge::AICORE_NUM), FLAGS_aicore_num));
  1008. options.insert(std::pair<string, string>(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize));
  1009. options.insert(std::pair<string, string>(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel));
  1010. options.insert(std::pair<string, string>(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file));
  1011. options.insert(std::pair<string, string>(string(ge::ENABLE_COMPRESS_WEIGHT),
  1012. (FLAGS_enable_compress_weight == "true") ?
  1013. ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse));
  1014. options.insert(std::pair<string, string>(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream));
  1015. options.insert(std::pair<string, string>(string(ge::DEBUG_DIR), FLAGS_debug_dir));
  1016. options.insert(std::pair<string, string>(string(ge::OP_COMPILER_CACHE_DIR), FLAGS_op_compiler_cache_dir));
  1017. options.insert(std::pair<string, string>(string(ge::OP_COMPILER_CACHE_MODE), FLAGS_op_compiler_cache_mode));
  1018. SetDynamicInputSizeOptions();
  1019. if (!FLAGS_save_original_model.empty()) {
  1020. options.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model));
  1021. options.insert(std::pair<string, string>(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om"));
  1022. }
  1023. options.insert(std::pair<string, string>(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level)));
  1024. options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path));
  1025. options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path));
  1026. options.insert(std::pair<string, string>(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info));
  1027. // set enable scope fusion passes
  1028. SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes);
  1029. // print atc option map
  1030. ge::PrintOptionMap(options, "atc option");
  1031. // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name
  1032. FLAGS_output = FLAGS_output + ".om";
  1033. ret = GenerateModel(options, FLAGS_output);
  1034. if (ret != domi::SUCCESS) {
  1035. return domi::FAILED;
  1036. }
  1037. if (FLAGS_display_model_info == "1") {
  1038. GELOGI("need to display model info.");
  1039. return ge::ConvertOm(FLAGS_output.c_str(), "", false);
  1040. }
  1041. return domi::SUCCESS;
  1042. }
  1043. domi::Status ConvertModelToJson() {
  1044. Status ret = GFlagUtils::CheckConverJsonParamFlags();
  1045. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!");
  1046. ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json);
  1047. GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED);
  1048. return domi::SUCCESS;
  1049. }
  1050. domi::Status DisplayModelInfo() {
  1051. // No model path passed in
  1052. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
  1053. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
  1054. return ge::FAILED,
  1055. "Input parameter[--om]'s value is empty!!");
  1056. // Check if the model path is valid
  1057. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  1058. FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
  1059. return ge::FAILED,
  1060. "model file path is invalid: %s.", FLAGS_om.c_str());
  1061. if (FLAGS_framework == -1) {
  1062. return ge::ConvertOm(FLAGS_om.c_str(), "", false);
  1063. }
  1064. return ge::FAILED;
  1065. }
  1066. bool CheckRet(domi::Status ret) {
  1067. if (ret != domi::SUCCESS) {
  1068. if (FLAGS_mode == ONLY_PRE_CHECK) {
  1069. GELOGW("ATC precheck failed.");
  1070. } else if (FLAGS_mode == GEN_OM_MODEL) {
  1071. GELOGW("ATC generate offline model failed.");
  1072. } else if (FLAGS_mode == MODEL_TO_JSON) {
  1073. GELOGW("ATC convert model to json file failed.");
  1074. } else if (FLAGS_mode == PBTXT_TO_JSON) {
  1075. GELOGW("ATC convert pbtxt to json file failed.");
  1076. } else {
  1077. return false;
  1078. }
  1079. return false;
  1080. }
  1081. if (FLAGS_mode == ONLY_PRE_CHECK) {
  1082. GELOGI("ATC precheck success.");
  1083. } else if (FLAGS_mode == GEN_OM_MODEL) {
  1084. GELOGI("ATC generate offline model success.");
  1085. } else if (FLAGS_mode == MODEL_TO_JSON) {
  1086. GELOGI("ATC convert model to json file success.");
  1087. } else if (FLAGS_mode == PBTXT_TO_JSON) {
  1088. GELOGI("ATC convert pbtxt to json file success.");
  1089. }
  1090. return true;
  1091. }
  1092. domi::Status ConvertPbtxtToJson() {
  1093. Status ret = GFlagUtils::CheckConverJsonParamFlags();
  1094. if (ret != domi::SUCCESS) {
  1095. GELOGE(ge::FAILED, "Check convert json params flags failed!");
  1096. return domi::FAILED;
  1097. }
  1098. ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str());
  1099. if (ret != domi::SUCCESS) {
  1100. GELOGE(ge::FAILED, "ConvertPbtxtToJson fail.");
  1101. return domi::FAILED;
  1102. }
  1103. return domi::SUCCESS;
  1104. }
  1105. int init(int argc, char* argv[]) {
  1106. GFlagUtils::InitGFlag(argc, argv);
  1107. // set log level
  1108. int ret = -1;
  1109. const std::set<string> log_level = {"null", "debug", "info", "warning", "error"};
  1110. if (log_level.count(FLAGS_log) == 0) {
  1111. std::cout << "E10010: invalid value for --log:" << FLAGS_log
  1112. <<", only support debug, info, warning, error, null"<< std::endl;
  1113. return ret;
  1114. }
  1115. ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log);
  1116. if (ret != 0) {
  1117. return ret;
  1118. }
  1119. std::string path_base = ge::GELib::GetPath();
  1120. ret = ErrorManager::GetInstance().Init(path_base);
  1121. if (ret != 0) {
  1122. DOMI_LOGE("ErrorManager init fail !");
  1123. return ret;
  1124. }
  1125. return 0;
  1126. }
  1127. long GetMemInfo(const std::string &key) {
  1128. std::string file_path = "/proc/meminfo";
  1129. std::ifstream fs(file_path, std::ifstream::in);
  1130. if (!fs.is_open()) {
  1131. GELOGW("Can not open %s .", file_path.c_str());
  1132. return 0;
  1133. }
  1134. std::string line;
  1135. while (getline(fs, line)) { // line not with \n
  1136. if (line.find(key) != std::string::npos) {
  1137. GELOGI("Find mem [%s] info line [%s]", key.c_str(), line.c_str());
  1138. fs.close();
  1139. size_t pos = line.find(":");
  1140. if (pos == std::string::npos) {
  1141. return 0;
  1142. }
  1143. std::string current_mem_info_str = line.substr(pos + 1);
  1144. ge::StringUtils::Trim(current_mem_info_str);
  1145. GELOGI("Find mem [%s] info [%s].", key.c_str(), current_mem_info_str.c_str());
  1146. return stol(current_mem_info_str);
  1147. }
  1148. }
  1149. fs.close(); // close the file
  1150. return 0;
  1151. }
  1152. bool CheckMemInfo() {
  1153. if (FLAGS_auto_tune_mode.empty()) {
  1154. return true;
  1155. }
  1156. // only check current available mem when auto_tune_mode is set.
  1157. long current_mem_available = GetMemInfo("MemAvailable");
  1158. GELOGI("Get mem available [%lu].", current_mem_available);
  1159. std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl;
  1160. if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) {
  1161. GELOGE(ge::PARAM_INVALID, "Current available mem [%lu] can not be smaller than [%lu] .",
  1162. current_mem_available, kMinAvailableMem);
  1163. ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"},
  1164. {to_string(current_mem_available), to_string(kMinAvailableMem)});
  1165. return false;
  1166. }
  1167. return true;
  1168. }
  1169. int main(int argc, char* argv[]) {
  1170. Status ret = domi::SUCCESS;
  1171. std::cout << "ATC start working now, please wait for a moment." << std::endl;
  1172. // Initialize
  1173. if (init(argc, argv) != 0) {
  1174. std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl;
  1175. return -1;
  1176. }
  1177. do {
  1178. if (!CheckMemInfo()) {
  1179. GELOGE(ge::PARAM_INVALID, "Current available mem is too small");
  1180. ret = domi::FAILED;
  1181. break;
  1182. }
  1183. if (!FLAGS_singleop.empty()) {
  1184. ret = GenerateSingleOp(FLAGS_singleop);
  1185. break;
  1186. }
  1187. // default mode(mode:0), Open source model to model
  1188. if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) {
  1189. GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break);
  1190. } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON
  1191. GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED;
  1192. break, "ATC ConvertJson execute failed!!");
  1193. } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) {
  1194. GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED;
  1195. break, "ATC convert pbtxt to json execute failed!!");
  1196. } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) {
  1197. GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED;
  1198. break, "ATC DisplayModelInfo failed!!");
  1199. } else {
  1200. ErrorManager::GetInstance().ATCReportErrMessage(
  1201. "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport});
  1202. GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport);
  1203. ret = domi::FAILED;
  1204. break;
  1205. }
  1206. } while (0);
  1207. if (!CheckRet(ret)) {
  1208. std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl;
  1209. int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO);
  1210. if (result != 0) {
  1211. DOMI_LOGE("ErrorManager outputErrMessage fail !");
  1212. }
  1213. GELOGI("Current mem available mem is [%lu]", GetMemInfo("MemAvailable"));
  1214. return ret;
  1215. } else {
  1216. std::cout << "ATC run success, welcome to the next use." << std::endl;
  1217. (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO);
  1218. return 0;
  1219. }
  1220. } /*lint +e530*/

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：