You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.cc 67 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <dirent.h>
  17. #include <dlfcn.h>
  18. #include <gflags/gflags.h>
  19. #include <sys/types.h>
  20. #include <unistd.h>
  21. #include <cctype>
  22. #include <climits>
  23. #include <cstdlib>
  24. #include <iostream>
  25. #include "framework/common/gflags_util.h"
  26. #include "framework/common/util.h"
  27. #include "common/util/error_manager/error_manager.h"
  28. #include "framework/common/debug/ge_log.h"
  29. #include "external/ge/ge_api.h"
  30. #include "framework/generator/ge_generator.h"
  31. #include "graph/anchor.h"
  32. #include "graph/debug/ge_attr_define.h"
  33. #include "external/graph/graph.h"
  34. #include "graph/op_desc.h"
  35. #include "graph/utils/graph_utils.h"
  36. #include "graph/utils/type_utils.h"
  37. #include "init/gelib.h"
  38. #include "ir_build/option_utils.h"
  39. #include "framework/omg/omg.h"
  40. #include "framework/omg/parser/parser_factory.h"
  41. #include "framework/omg/parser/parser_inner_ctx.h"
  42. #include "parser/common/register_tbe.h"
  43. #include "register/op_registry.h"
  44. #include "offline/single_op_parser.h"
  45. #include "external/ge/ge_ir_build.h"
  46. using domi::BuildMode;
  47. using domi::OpRegistrationData;
  48. using domi::OpRegistry;
  49. using domi::Status;
  50. using domi::SUCCESS;
  51. using ge::GEN_OM_MODEL;
  52. using ge::GflagsUtils;
  53. using ge::MODEL_TO_JSON;
  54. using ge::ONLY_PRE_CHECK;
  55. using ge::ParseInputShape;
  56. using ge::PBTXT_TO_JSON;
  57. using std::map;
  58. using std::pair;
  59. using std::shared_ptr;
  60. using std::string;
  61. using std::vector;
  namespace {
  // Handed to ge::CheckDynamicInputParamValid() from GFlagUtils::CheckFlags();
  // presumably filled in there as an out-parameter -- TODO confirm against that helper.
  bool is_dynamic_input = false;

  // Messages describing which values the tool accepts for --mode, --framework and --input_format.
  const char *const kModeSupport =
      "only support 0(model to framework model), 1(framework model to json), "
      "3(only pre-check), 5(pbtxt to json), 6(display model info)";
  const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx) when model set 1";
  const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
  const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
  const char *const kONNXFormatSupport = "only support NCHW, ND, NCDHW in ONNX model";

  // Lower bound on available memory: 2 * 1024 * 1024.
  // The original note calls this "2G", so the unit is presumably KB -- TODO confirm at the use site.
  constexpr long kMinAvailableMem = 2097152;
  }  // namespace
  74. DEFINE_string(model, "", "The model file.");
  75. DEFINE_string(output, "", "The output file path&name.");
  76. DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx).");
  77. DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe.");
  78. DEFINE_string(input_shape, "",
  79. "Optional; shape of input data. Required when framework is caffe "
  80. "or TensorFLow or MindSpore or Onnx. "
  81. "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"");
  82. DEFINE_string(input_shape_range, "",
  83. "Optional; shape range of input data. Required when framework is caffe "
  84. "or TensorFLow or Onnx. "
  85. "Format: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2~n3,c2,h2,w2]\"");
  86. DEFINE_bool(h, false, "show this help message");
  87. DEFINE_string(cal_conf, "", "Optional; the calibration config file.");
  88. DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op.");
  89. DEFINE_string(op_name_map, "", "Optional; custom op name mapping file.");
  90. DEFINE_string(target, "", "Optional; mini.");
  91. DEFINE_string(om, "", "The model file to be converted to json.");
  92. DEFINE_string(json, "", "The output json file path&name which is converted from a model.");
  93. DEFINE_int32(mode, 0,
  94. "Optional; run mode, 0(default): model => framework model; 1: "
  95. "framework model => json; 3: only pre-check; 5: txt => json.");
  96. DEFINE_string(out_nodes, "",
  97. "Optional; output nodes designated by users."
  98. "Format: \"node_name1:0;node_name1:1;node_name2:0\"");
  99. DEFINE_string(op_precision_mode, "", "Optional; operator precision mode configuration file path");
  100. DEFINE_string(precision_mode, "force_fp16",
  101. "Optional; precision mode."
  102. "Support force_fp16, force_fp32, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype.");
  103. DEFINE_string(modify_mixlist, "", "Optional; operator mixed precision configuration file path");
  104. DEFINE_string(keep_dtype, "",
  105. "Optional; config file to specify the precision used by the operator during compilation.");
  106. DEFINE_string(input_format, "",
  107. "Optional; input_format, format of input data, NCHW;NHWC."
  108. "Format:\"NHWC\"");
  109. DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file.");
  110. DEFINE_string(input_fp16_nodes, "",
  111. "Optional; input node datatype is fp16 and format is NC1HWC0."
  112. "Format:\"node_name1;node_name2\"");
  113. DEFINE_string(is_output_adjust_hw_layout, "",
  114. "Optional; Net output node's datatype is fp16 and format is "
  115. "NC1HWC0, or not."
  116. "Format:\"false,true,false,true\"");
  117. DEFINE_string(is_input_adjust_hw_layout, "",
  118. "Optional; Intput node's datatype is fp16 and format is "
  119. "NC1HWC0, or not."
  120. "Format:\"false,true,false,true\"");
  121. DEFINE_string(output_type, "",
  122. "Optional; output type! "
  123. "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE.");
  124. DEFINE_string(op_select_implmode, "",
  125. "Optional; op select implmode! "
  126. "Support high_precision, high_performance.");
  127. DEFINE_string(optypelist_for_implmode, "",
  128. "Optional; Nodes need use implmode selected in op_select_implmode "
  129. "Format:\"node_name1,node_name2\"");
  130. DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file.");
  131. DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if.");
  132. DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode.");
  133. DEFINE_string(soc_version, "", "The soc version.");
  134. DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core.");
  135. DEFINE_string(aicore_num, "", "Optional; Set aicore num");
  136. DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize");
  137. DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path");
  138. DEFINE_string(save_original_model, "", "Optional; enable output original offline model. false(default)");
  139. DEFINE_string(dynamic_batch_size, "",
  140. "Optional; If set, generate dynamic multi batch model. "
  141. "Different batch sizes are split by ','."
  142. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  143. DEFINE_string(dynamic_image_size, "",
  144. "Optional; If set, generate dynamic multi image size model."
  145. "Different groups of image size are split by ';',"
  146. "while different dimensions of each group are split by ','."
  147. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  148. DEFINE_string(dynamic_dims, "",
  149. "Optional; If set, generate dynamic input size model. "
  150. "Different groups of size are split by ';', while different dimensions of each group are split by ','."
  151. "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one.");
  152. DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled.");
  153. DEFINE_string(enable_compress_weight, "false",
  154. "Optional; enable compress weight. true: enable; false(default): disable");
  155. DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight");
  156. DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable");
  157. DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null");
  158. DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0.");
  159. DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug; "
  160. "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler; "
  161. "3: disable debug, and keep generating kernel file (.o and .json); 4: disable debug, "
  162. "keep generation kernel file (.o and .json) and generate the operator CCE file (.cce) "
  163. "and the UB fusion computing description file (.json)");
  164. DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass,"
  165. "multiple names can be set and separated by ','.");
  166. DEFINE_string(debug_dir, "", "Optional; the path to save the intermediate files of operator compilation");
  167. DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator compilation files");
  168. DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode");
  169. DEFINE_string(mdl_bank_path, "", "Optional; model bank path");
  170. DEFINE_string(op_bank_path, "", "Optional; op bank path");
  171. DEFINE_string(display_model_info, "0", "Optional; display model info");
  172. DEFINE_string(device_id, "0", "Optional; user device id");
  173. class GFlagUtils {
  174. public:
  175. /**
  176. * @name InitGFlag
  177. * @brief initialize gflag
  178. * @return void
  179. */
  180. static void InitGFlag(int argc, char *argv[]) {
  181. // -help
  182. gflags::SetUsageMessage(
  183. "usage: ./atc <args>\n"
  184. "generate offline model example:\n"
  185. "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n"
  186. "--framework=0 --output=./domi \n"
  187. "generate offline model for single op example:\n"
  188. "./atc --singleop=./op_list.json --output=./op_model \n"
  189. "===== Basic Functionality =====\n"
  190. "[General]\n"
  191. " --h/help Show this help message\n"
  192. " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; "
  193. "3: only pre-check; 5: convert ge dump txt file to JSON format; 6: display model info\n"
  194. "\n[Input]\n"
  195. " --model Model file\n"
  196. " --weight Weight file. Required when framework is Caffe\n"
  197. " --om The model file to be converted to json\n"
  198. " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n"
  199. " --input_format Format of input data. E.g.: \"NCHW\"\n"
  200. " --input_shape Shape of input data. Separate multiple nodes with semicolons (;). "
  201. "Use double quotation marks (\") to enclose each argument.\n"
  202. " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n"
  203. " --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)."
  204. "Use double quotation marks (\") to enclose each argument.\n"
  205. " E.g.: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2,c2~c3,h2,w2]\"\n"
  206. " --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n"
  207. " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). "
  208. "Use double quotation marks (\") to enclose each argument.\n"
  209. " E.g.: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n"
  210. " --dynamic_dims Set dynamic dims. Separate multiple nodes with semicolons (;). "
  211. "Use double quotation marks (\") to enclose each argument.\n"
  212. " E.g.: \"dims1_n1,dims1_n2;dims2_n1,dims2_n2\"\n"
  213. " --singleop Single op definition file. atc will generate offline "
  214. "model(s) for single op if --singleop is set.\n"
  215. "\n[Output]\n"
  216. " --output Output file path&name(needn't suffix, will add .om automatically). \n"
  217. " If --singleop is set, this arg specifies the directory to "
  218. "which the single op offline model will be generated\n"
  219. " --output_type Set net output type. Support FP32, FP16, UINT8. "
  220. "E.g.: FP16, indicates that all out nodes are set to FP16.\n"
  221. " \"node1:0:FP16;node2:1:FP32\", indicates setting the datatype of multiple out nodes.\n"
  222. " --check_report The pre-checking report file. Default value is: \"check_result.json\"\n"
  223. " --json The output json file path&name which is converted from a model\n"
  224. "\n[Target]\n"
  225. " --soc_version The soc version.\n"
  226. " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. "
  227. "Default value is: AiCore\n"
  228. " --aicore_num Set aicore num\n"
  229. "===== Advanced Functionality =====\n"
  230. "[Feature]\n"
  231. " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)."
  232. "Use double quotation marks (\") to enclose each argument.\n"
  233. " E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n"
  234. " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons (;). "
  235. "Use double quotation marks (\") to enclose each argument. "
  236. "E.g.: \"node_name1;node_name2\"\n"
  237. " --insert_op_conf Config file to insert new op\n"
  238. " --op_name_map Custom op name mapping file\n"
  239. " Note: A semicolon(;) cannot be included in each "
  240. "path, otherwise the resolved path will not match the expected one.\n"
  241. " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is "
  242. "NC1HWC0, used with input_fp16_nodes. E.g.: \"true,true,false,true\"\n"
  243. " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is "
  244. "NC1HWC0, used with out_nodes. E.g.: \"true,true,false,true\"\n"
  245. "\n[Model Tuning]\n"
  246. " --disable_reuse_memory The switch of reuse memory. Default value is : 0. "
  247. "0 means reuse memory, 1 means do not reuse memory.\n"
  248. " --fusion_switch_file Set fusion switch file path\n"
  249. " --enable_scope_fusion_passes validate the non-general scope fusion passes, "
  250. "multiple names can be set and separated by ','. E.g.: ScopePass1,ScopePass2,...\n"
  251. " --enable_single_stream Enable single stream. true: enable; false(default): disable\n"
  252. " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n"
  253. " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n"
  254. " --compress_weight_conf Config file to compress weight\n"
  255. " --buffer_optimize Set buffer optimize. Support \"l2_optimize\" (default), "
  256. "\"l1_optimize\", \"off_optimize\"\n"
  257. " --mdl_bank_path Set the path of the custom repository generated after model tuning.\n"
  258. "\n[Operator Tuning]\n"
  259. " --op_precision_mode Set the path of operator precision mode configuration file (.ini)\n"
  260. " --precision_mode precision mode, support force_fp16(default), force_fp32, allow_mix_precision, "
  261. "allow_fp32_to_fp16, must_keep_origin_dtype.\n"
  262. " --modify_mixlist Set the path of operator mixed precision configuration file.\n"
  263. " --keep_dtype Retains the precision of certain operators in inference "
  264. "scenarios by using a configuration file.\n"
  265. " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
  266. " --op_bank_path Set the path of the custom repository generated after operator tuning with Auto Tune.\n"
  267. " --op_select_implmode Set op select implmode. Support high_precision, high_performance. "
  268. "default: high_performance\n"
  269. " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n"
  270. " Separate multiple nodes with commas (,). Use double quotation marks (\") "
  271. "to enclose each argument. E.g.: \"node_name1,node_name2\"\n"
  272. " --op_debug_level Debug enable for TBE operator building.\n"
  273. " 0 (default): Disable debug; 1: Enable TBE pipe_all, "
  274. "and generate the operator CCE file and Python-CCE mapping file (.json);\n"
  275. " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file "
  276. "(.json), and enable the CCE compiler -O0-g.\n"
  277. " 3: Disable debug, and keep generating kernel file (.o and .json)\n"
  278. " 4: Disable debug, keep generation kernel file (.o and .json) and generate the "
  279. "operator CCE file (.cce) and the UB fusion computing description file (.json)"
  280. "\n[Debug]\n"
  281. " --save_original_model Control whether to output original model. E.g.: true: output original model\n"
  282. " --log Generate log with level. Support debug, info, warning, error, null\n"
  283. " --dump_mode The switch of dump json with shape, to be used with mode 1. "
  284. "0(default): disable; 1: enable.\n"
  285. " --debug_dir Set the save path of operator compilation intermediate files.\n"
  286. "Default value: ./kernel_meta\n"
  287. " --op_compiler_cache_dir Set the save path of operator compilation cache files.\n"
  288. "Default value: $HOME/atc_data\n"
  289. " --op_compiler_cache_mode Set the operator compilation cache mode."
  290. "Options are disable(default), enable and force(force to refresh the cache)\n"
  291. " --display_model_info enable for display model info; 0(default): close display, 1: open display.");
  292. gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
  293. // Using gflags to analyze input parameters
  294. GflagsUtils::ChangeHelpFlags(FLAGS_h);
  295. gflags::HandleCommandLineHelpFlags();
  296. }
  297. static Status CheckDumpInfershapeJsonFlags() {
  298. Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight);
  299. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED,
  300. "[Check][Param:FrameWork]%d value is invalid.", FLAGS_framework);
  301. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  302. FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"),
  303. return domi::FAILED, "[Check][Param:weight]value:%s: is invalid, path can not reach.",
  304. FLAGS_weight.c_str());
  305. return domi::SUCCESS;
  306. }
  307. static Status CheckFlags() {
  308. Status ret = ge::SUCCESS;
  309. // No model file information passed in
  310. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  311. FLAGS_model == "",
  312. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"});
  313. ret = ge::FAILED, "[Check][Param]Input parameter[--model]'s value is empty!");
  314. // check param disable_reuse_memory
  315. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  316. ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) != ge::SUCCESS,
  317. ret = ge::FAILED, "[Check][DisableReuseMemory]failed!");
  318. // check optypelist_for_implmode and op_select_implmode
  319. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  320. ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode,
  321. FLAGS_op_select_implmode) != ge::SUCCESS,
  322. ret = ge::FAILED, "[Check][ImplMode]check optypelist_for_implmode and op_select_implmode failed!");
  323. if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) {
  324. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  325. {"op_precision_mode", FLAGS_op_precision_mode.c_str(),
  326. "path is not found"});
  327. GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str());
  328. ret = ge::FAILED;
  329. }
  330. if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) {
  331. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  332. {"modify_mixlist", FLAGS_modify_mixlist.c_str(),
  333. ge::kModifyMixlistError});
  334. ret = ge::FAILED;
  335. }
  336. // No output file information passed in
  337. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  338. FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "",
  339. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"});
  340. ret = ge::FAILED, "[Check][Param]Input parameter[--output]'s value is empty!");
  341. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  342. CheckFrameWorkValid(FLAGS_framework, FLAGS_weight) != ge::SUCCESS,
  343. ret = ge::FAILED,
  344. "[Check][FrameWork] failed for input --FLAGS_framework and --FLAGS_weight invalid.");
  345. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  346. ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size,
  347. FLAGS_dynamic_dims, FLAGS_input_shape, FLAGS_input_shape_range,
  348. FLAGS_input_format, is_dynamic_input) != ge::SUCCESS,
  349. ret = ge::FAILED, "[Check][DynamicInput]dynamic size(batch size, image size or dims) invalid!");
  350. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  351. !FLAGS_insert_op_conf.empty() && !FLAGS_dynamic_dims.empty(),
  352. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  353. {"parameter", "value", "reason"},
  354. {"--insert_op_conf", FLAGS_insert_op_conf,
  355. "dynamic dims function does not support aipp"});
  356. ret = ge::FAILED, "[Check][Param]dynamic dims function does not support aipp");
  357. /**
  358. * Check the validity of the I / O file path
  359. */
  360. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  361. FLAGS_model != "" && !ge::CheckInputPathValid(FLAGS_model, "--model"), ret = ge::FAILED,
  362. "[Check][InputPath]model file %s not found!!", FLAGS_model.c_str());
  363. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  364. FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"),
  365. ret = ge::FAILED, "[Check][InputPath]weight file %s not found!!",
  366. FLAGS_weight.c_str());
  367. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  368. FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"),
  369. ret = ge::FAILED, "[Check][InputPath]calibration config file %s not found!!",
  370. FLAGS_cal_conf.c_str());
  371. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  372. FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"),
  373. ret = ge::FAILED, "[Check][InputPath]op config file %s not found!!",
  374. FLAGS_op_name_map.c_str());
  375. GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS,
  376. ret = ge::FAILED, "[Check][InsertOpConf]failed!");
  377. GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid(
  378. FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS,
  379. ret = ge::FAILED, "[Check][CompressWeight]failed!");
  380. GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS,
  381. ret = ge::FAILED, "[Check][KeepType]failed!");
  382. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  383. !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED,
  384. "[Check][OutputPath]]check_report file %s not found!!", FLAGS_check_report.c_str());
  385. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  386. FLAGS_mode == GEN_OM_MODEL && FLAGS_output != "" &&
  387. (!ge::CheckOutputPathValid(FLAGS_output, "--output") || !CheckPathWithName(FLAGS_output)),
  388. ret = ge::FAILED, "[Check][OutputPath]output path %s is not valid!!", FLAGS_output.c_str());
  389. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  390. FLAGS_save_original_model != "" &&
  391. FLAGS_save_original_model != "true" &&
  392. FLAGS_save_original_model != "false",
  393. ErrorManager::GetInstance().ATCReportErrMessage(
  394. "E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model});
  395. ret = ge::FAILED,
  396. "[Check][Parameter]Input parameter[--save_original_model]'s value[%s] must be true or false.",
  397. FLAGS_save_original_model.c_str());
  398. GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS,
  399. ret = ge::FAILED, "[Check][BufferOptimize]check output type failed!");
  400. GE_CHK_BOOL_EXEC(
  401. ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS,
  402. ret = ge::FAILED, "[Check][EnableSingleStream]failed!");
  403. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((FLAGS_display_model_info != "0") && (FLAGS_display_model_info != "1"),
  404. REPORT_INPUT_ERROR("E10006", std::vector<std::string>({"parameter", "value"}),
  405. std::vector<std::string>({"display_model_info", FLAGS_display_model_info}));
  406. ret = ge::FAILED, "[Check][Parameter]Input parameter[--display_model_info]'s value must be 1 or 0.");
  407. return ret;
  408. }
  409. /**
  410. * Verifying the parameters of converting model to JSON
  411. * 1. Fmk_model
  412. * 2. out_json
  413. **/
  414. static Status CheckConverJsonParamFlags() {
  415. Status ret = ge::SUCCESS;
  416. // No model path passed in
  417. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
  418. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
  419. ret = ge::FAILED,
  420. "[Check][Parameter]Input parameter[--om]'s value is empty!!");
  421. // JSON path not passed in
  422. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "",
  423. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"});
  424. ret = ge::FAILED,
  425. "[Check][Parameter]Input parameter[--json]'s value is empty!!");
  426. // Check if the model path is valid
  427. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  428. FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
  429. ret = ge::FAILED,
  430. "[Check][InputPath]model file path is invalid: %s.", FLAGS_om.c_str());
  431. // Check whether the JSON path is valid
  432. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  433. FLAGS_json != "" && !ge::CheckOutputPathValid(FLAGS_json, "--json"),
  434. ret = ge::FAILED,
  435. "[Check][OutputPath]json file path is invalid: %s.", FLAGS_json.c_str());
  436. return ret;
  437. }
  438. /**
  439. * Check command line parameters for explicit settings
  440. * true: Explicit setup
  441. * false: Not set up
  442. * */
  443. static bool CheckFlagSet(string flag) {
  444. gflags::CommandLineFlagInfo info;
  445. return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default);
  446. }
  447. private:
  448. static bool CheckEncryptModeValid(const int encrypt_mode) {
  449. #if !defined(__ANDROID__) && !defined(ANDROID)
  450. if (encrypt_mode != 0 && encrypt_mode != -1) {
  451. DOMI_LOGE("encrypt mode must be 0 or -1");
  452. return false;
  453. }
  454. #else
  455. if (encrypt_mode != -1) {
  456. DOMI_LOGE("encrypt mode must be -1");
  457. return false;
  458. }
  459. #endif
  460. return true;
  461. }
  462. static Status CheckFrameWorkValid(int framework, const std::string weight_file) {
  463. if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW &&
  464. framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) {
  465. // No framework information was passed in or the entered framework is illegal
  466. ErrorManager::GetInstance().ATCReportErrMessage(
  467. "E10007", {"parameter", "support"},
  468. {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)"});
  469. DOMI_LOGE("Input parameter[--framework] is mandatory and it's value must be: "
  470. "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx).");
  471. return domi::PARAM_INVALID;
  472. }
  473. if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) {
  474. ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"});
  475. DOMI_LOGE("Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!");
  476. return domi::PARAM_INVALID;
  477. }
  478. if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) {
  479. GELOGW("Parameter weight is ignored for TensorFlow.");
  480. }
  481. if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) {
  482. GELOGW("Parameter weight is ignored for Onnx.");
  483. }
  484. return domi::SUCCESS;
  485. }
  486. static bool CheckPathWithName(const std::string &fileName) {
  487. // Determine file path length
  488. if (fileName.size() > static_cast<int>(PATH_MAX)) {
  489. ErrorManager::GetInstance().ATCReportErrMessage(
  490. "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)});
  491. GELOGE(ge::FAILED,
  492. "[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX);
  493. return false;
  494. }
  495. // Find the last separator
  496. int slashPosition = fileName.size() - 1;
  497. for (; slashPosition >= 0; slashPosition--) {
  498. if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') {
  499. break;
  500. }
  501. }
  502. // Failure if no filename follows the path
  503. if (slashPosition == static_cast<int>(fileName.size() - 1)) {
  504. ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName});
  505. DOMI_LOGE("Input parameter[--output]'s path[%s] not include file name", fileName.c_str());
  506. return false;
  507. }
  508. return true;
  509. }
  510. };
  511. void SetDynamicInputSizeOptions() {
  512. if (!FLAGS_dynamic_batch_size.empty()) {
  513. domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size;
  514. }
  515. if (!FLAGS_dynamic_image_size.empty()) {
  516. domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size;
  517. }
  518. if (!FLAGS_dynamic_dims.empty()) {
  519. domi::GetContext().dynamic_dims = FLAGS_dynamic_dims;
  520. }
  521. }
  522. /// Validate the non-general scope fusion pass.
  523. /// The parameter is set to the name of the fusion rule.
  524. /// Multiple names can be set and separated by ",".
  525. void SetEnableScopeFusionPasses(const std::string pass_names) {
  526. ge::GetParserContext().enable_scope_fusion_passes = pass_names;
  527. }
  528. static bool CheckInputFormat() {
  529. if (FLAGS_input_format.empty()) {
  530. // Set default format
  531. if (FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW)) {
  532. FLAGS_input_format = "NHWC";
  533. } else {
  534. FLAGS_input_format = "NCHW";
  535. }
  536. return true;
  537. } else if ((FLAGS_framework == static_cast<int32_t>(domi::CAFFE))) { // caffe
  538. if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) {
  539. return true;
  540. }
  541. // only support NCHW ND
  542. ErrorManager::GetInstance().ATCReportErrMessage(
  543. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport});
  544. GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
  545. FLAGS_input_format.c_str(), kCaffeFormatSupport);
  546. return false;
  547. } else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf
  548. if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) {
  549. return true;
  550. }
  551. // only support NCHW NHWC ND NCDHW NDHWC
  552. ErrorManager::GetInstance().ATCReportErrMessage(
  553. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport});
  554. GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
  555. FLAGS_input_format.c_str(), kTFFormatSupport);
  556. return false;
  557. } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
  558. if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) {
  559. return true;
  560. }
  561. // only support NCHW ND
  562. ErrorManager::GetInstance().ATCReportErrMessage(
  563. "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport});
  564. GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
  565. FLAGS_input_format.c_str(), kONNXFormatSupport);
  566. return false;
  567. }
  568. return true;
  569. }
  570. #if !defined(__ANDROID__) && !defined(ANDROID)
  571. static void GetCustomOpPath(std::string &customop_path) {
  572. GELOGI("Enter get custom op path schedule");
  573. std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast<domi::FrameworkType>(FLAGS_framework));
  574. GELOGI("Framework type is %s.", fmk_type.c_str());
  575. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  576. if (path_env != nullptr) {
  577. std::string path = path_env;
  578. customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
  579. GELOGI("Get custom so path from env : %s", path_env);
  580. return;
  581. }
  582. std::string path_base = ge::GELib::GetPath();
  583. GELOGI("path_base is %s", path_base.c_str());
  584. path_base = path_base.substr(0, path_base.rfind('/'));
  585. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  586. customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type);
  587. return;
  588. }
  589. void GetPluginSoFileList(const string &path, vector<string> &fileList, string &caffe_parser_path) {
  590. // Support to split multiple so directories by ":"
  591. GELOGI("path is %s", path.c_str());
  592. vector<string> v_path = ge::StringUtils::Split(path, ':');
  593. for (size_t i = 0; i < v_path.size(); ++i) {
  594. ge::FindParserSo(v_path[i], fileList, caffe_parser_path);
  595. GELOGI("CustomOpLib full name = %s", v_path[i].c_str());
  596. }
  597. }
  598. void LoadModelParserLib(std::string caffe_parser_path) {
  599. if (FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW)) {
  600. void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL);
  601. if (tf_handle == nullptr) {
  602. GELOGW("dlopen fmk library [libfmk_parser.so] failed.");
  603. return;
  604. }
  605. GELOGI("plugin load libfmk_parser.so success.");
  606. } else if (FLAGS_framework == static_cast<int32_t>(domi::CAFFE)) {
  607. // What we are dealing with here is that the user modifies the caffe.proto scenario.
  608. // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path.
  609. caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path;
  610. void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
  611. if (handle == nullptr) {
  612. GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror());
  613. return;
  614. }
  615. GELOGI("plugin load %s success.", caffe_parser_path.c_str());
  616. // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so).
  617. // (depend on the lib_caffe_parser.so)
  618. void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL);
  619. if (fmk_handle == nullptr) {
  620. GELOGW("dlopen fmk library [libfmk_parser.so] failed.");
  621. if (dlclose(handle) != 0) {
  622. GELOGW("dlclose lib_caffe_parser.so failed.");
  623. }
  624. return;
  625. }
  626. GELOGI("plugin load libfmk_parser.so success.");
  627. } else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
  628. void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL);
  629. if (handle == nullptr) {
  630. GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed.");
  631. return;
  632. }
  633. GELOGI("plugin load libfmk_onnx_parser.so success.");
  634. } else {
  635. GELOGW("Framework:%s is not support.",
  636. ge::TypeUtils::FmkTypeToSerialString(static_cast<domi::FrameworkType>(FLAGS_framework)).c_str());
  637. return;
  638. }
  639. return;
  640. }
  641. void LoadCustomOpLib(bool need_load_ops_plugin) {
  642. std::string plugin_path;
  643. GetCustomOpPath(plugin_path);
  644. vector<string> fileList;
  645. string caffe_parser_path = "";
  646. // whether there are files in the plugin so path
  647. GetPluginSoFileList(plugin_path, fileList, caffe_parser_path);
  648. // no file
  649. if (fileList.empty() && caffe_parser_path.empty()) {
  650. GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str());
  651. }
  652. LoadModelParserLib(caffe_parser_path);
  653. if (!need_load_ops_plugin) {
  654. GELOGI("No need to load ops plugin so.");
  655. return;
  656. }
  657. OpRegistry::Instance()->registrationDatas.clear();
  658. // load other so files except lib_caffe_parser.so in the plugin so path
  659. for (auto elem : fileList) {
  660. ge::StringUtils::Trim(elem);
  661. void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL);
  662. if (handle == nullptr) {
  663. GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror());
  664. } else {
  665. GELOGI("plugin load %s success.", elem.c_str());
  666. }
  667. }
  668. std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
  669. for (OpRegistrationData reg_data : registrationDatas) {
  670. if (reg_data.GetFrameworkType() == static_cast<domi::FrameworkType>(FLAGS_framework)) {
  671. (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data);
  672. (void)OpRegistry::Instance()->Register(reg_data);
  673. }
  674. }
  675. }
  676. void SaveCustomCaffeProtoPath() {
  677. GELOGI("Enter save custom caffe proto path.");
  678. std::string path_base = ge::GELib::GetPath();
  679. GELOGI("path_base is %s", path_base.c_str());
  680. path_base = path_base.substr(0, path_base.rfind('/'));
  681. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  682. ge::GetParserContext().caffe_proto_path = path_base + "include/proto/";
  683. string customop_path;
  684. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  685. if (path_env != nullptr) {
  686. std::string path = path_env;
  687. customop_path = path + "/framework/custom/caffe/";
  688. GELOGI("Get custom proto path from env : %s", path_env);
  689. ge::GetParserContext().custom_proto_path = customop_path;
  690. return;
  691. }
  692. customop_path = path_base + "ops/framework/custom/caffe/";
  693. ge::GetParserContext().custom_proto_path = customop_path;
  694. return;
  695. }
  696. #endif
  697. Status CreateInputsForInference(const ge::Graph &graph, vector<ge::GeTensor> &inputs) {
  698. auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  699. GE_CHECK_NOTNULL(compute_graph);
  700. int64_t index = 0;
  701. for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) {
  702. GE_CHECK_NOTNULL(input_node);
  703. ge::OpDescPtr op = input_node->GetOpDesc();
  704. GE_CHECK_NOTNULL(op);
  705. if (op->GetType() == ge::DATA) {
  706. if (!op->HasAttr(ge::ATTR_NAME_INDEX)) {
  707. (void)ge::AttrUtils::SetInt(op, ge::ATTR_NAME_INDEX, index);
  708. GELOGD("Set attr index:%ld for data op:%s", index, op->GetName().c_str());
  709. }
  710. index++;
  711. GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size());
  712. ge::GeTensorDesc tensor = op->GetInputDesc(0);
  713. string data_op_name = op->GetName();
  714. GELOGI("Data op name is: %s", data_op_name.c_str());
  715. ge::GeShape data_shape;
  716. auto iter = domi::GetContext().input_dims.find(data_op_name);
  717. if (iter != domi::GetContext().input_dims.end()) {
  718. data_shape = ge::GeShape(iter->second);
  719. GELOGI("Data op get shape from Context.");
  720. } else {
  721. data_shape = tensor.GetShape();
  722. GELOGI("Data op get shape from InputDesc in geir graph.");
  723. }
  724. ge::DataType data_type = tensor.GetDataType();
  725. string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type);
  726. GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str());
  727. ge::GeTensor input_tensor;
  728. ge::GeTensorDesc desc(data_shape, ge::Format(domi::GetContext().format), data_type);
  729. input_tensor.SetTensorDesc(desc);
  730. inputs.push_back(input_tensor);
  731. }
  732. }
  733. GELOGI("Build ME model, inputs size is: %zu", inputs.size());
  734. return ge::SUCCESS;
  735. }
  736. domi::Status GenerateInfershapeJson() {
  737. if (!CheckInputFormat()) {
  738. GELOGE(ge::FAILED, "[Check][InputFormat] failed.");
  739. return domi::FAILED;
  740. }
  741. Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags();
  742. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "[Check][DumpInfershapeJsonFlags] failed!");
  743. ge::GeGenerator ge_generator;
  744. std::map<string, string> options;
  745. ge::Status geRet = ge_generator.Initialize(options, domi::GetContext());
  746. if (geRet != ge::SUCCESS) {
  747. DOMI_LOGE("GeGenerator initialize failed!");
  748. return domi::FAILED;
  749. }
  750. ge::Graph graph;
  751. std::map<string, string> atc_params;
  752. atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
  753. atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report));
  754. ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework,
  755. "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false);
  756. if (ret != ge::SUCCESS) {
  757. DOMI_LOGE("ATC Parse graph domi::FAILED");
  758. (void)ge_generator.Finalize();
  759. return domi::FAILED;
  760. }
  761. geRet = ge_generator.GenerateInfershapeGraph(graph);
  762. if (geRet != ge::SUCCESS) {
  763. DOMI_LOGE("ATC GenerateInfershapeJson failed");
  764. (void)ge_generator.Finalize();
  765. return domi::FAILED;
  766. }
  767. if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) {
  768. DOMI_LOGE("ATC DumpInfershapeJson failed");
  769. (void)ge_generator.Finalize();
  770. return domi::FAILED;
  771. }
  772. (void)ge_generator.Finalize();
  773. return ge::SUCCESS;
  774. }
  775. static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) {
  776. Status ret = ge::SUCCESS;
  777. if (fwk_type == -1) {
  778. ret = ge::ConvertOm(model_file.c_str(), json_file.c_str(), true);
  779. return ret;
  780. }
  781. if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) {
  782. ErrorManager::GetInstance().ATCReportErrMessage(
  783. "E10001", {"parameter", "value", "reason"},
  784. {"--framework", std::to_string(fwk_type), kModelToJsonSupport});
  785. GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.",
  786. fwk_type, kModelToJsonSupport);
  787. ret = ge::FAILED;
  788. }
  789. if (FLAGS_dump_mode != "0" && FLAGS_dump_mode != "1") {
  790. REPORT_INPUT_ERROR("E10006", std::vector<std::string>({"parameter", "value"}),
  791. std::vector<std::string>({"dump_mode", FLAGS_dump_mode}));
  792. GELOGE(ge::FAILED, "[Convert][ModelToJson] Input parameter[--dump_mode]'s value must be 1 or 0.");
  793. ret = ge::FAILED;
  794. }
  795. if (ret != ge::SUCCESS) return ret;
  796. // Need to save caffe.proto path
  797. SaveCustomCaffeProtoPath();
  798. if (FLAGS_dump_mode == "0") {
  799. // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so.
  800. LoadCustomOpLib(false);
  801. ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str());
  802. } else if (FLAGS_dump_mode == "1") {
  803. // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so.
  804. LoadCustomOpLib(true);
  805. ret = GenerateInfershapeJson();
  806. }
  807. return ret;
  808. }
  809. static Status SetAttrOptions(ge::Graph &graph) {
  810. if (!FLAGS_keep_dtype.empty()) {
  811. if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_KEEP_DTYPE, FLAGS_keep_dtype.c_str()) != ge::GRAPH_SUCCESS) {
  812. return ge::FAILED;
  813. }
  814. }
  815. if (!FLAGS_compress_weight_conf.empty()) {
  816. if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_WEIGHT_COMPRESS, FLAGS_compress_weight_conf.c_str())
  817. != ge::GRAPH_SUCCESS) {
  818. return ge::FAILED;
  819. }
  820. }
  821. return ge::SUCCESS;
  822. }
  823. domi::Status GenerateModel(std::map<string, string> &options, std::string output) {
  824. ge::GeGenerator ge_generator;
  825. ge::Status geRet = ge::SUCCESS;
  826. std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance();
  827. if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
  828. geRet = ge::GELib::Initialize(options);
  829. if (geRet != ge::SUCCESS) {
  830. DOMI_LOGE("GE initialize failed!");
  831. return domi::FAILED;
  832. }
  833. }
  834. geRet = ge_generator.Initialize(options, domi::GetContext());
  835. if (geRet != ge::SUCCESS) {
  836. DOMI_LOGE("GeGenerator initialize failed!");
  837. (void)ge::GELib::GetInstance()->Finalize();
  838. return domi::FAILED;
  839. }
  840. ge::Graph graph;
  841. std::vector<ge::GeTensor> inputs;
  842. if (FLAGS_framework == domi::MINDSPORE) {
  843. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  844. // load model from file
  845. ge::Model load_model = ge::Model("loadmodel", "version2");
  846. auto ret1 = load_model.LoadFromFile(FLAGS_model);
  847. if (ret1 != ge::GRAPH_SUCCESS) {
  848. REPORT_INPUT_ERROR("E10041", std::vector<std::string>({"parameter"}), std::vector<std::string>({FLAGS_model}));
  849. DOMI_LOGE("Load model from %s failed, please check model file or "
  850. "input parameter[--framework] is correct", FLAGS_model.c_str());
  851. (void)ge_generator.Finalize();
  852. (void)ge::GELib::GetInstance()->Finalize();
  853. return domi::FAILED;
  854. }
  855. graph = load_model.GetGraph();
  856. GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input),
  857. GELOGE(ge::FAILED, "[Init][DomiOmgContext]ATC Generate call InitDomiOmgContext ret fail");
  858. (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED);
  859. Status ret = CreateInputsForInference(graph, inputs);
  860. if (ret != ge::SUCCESS) {
  861. GELOGE(ge::FAILED, "[Create][InputsForInference] failed.");
  862. REPORT_CALL_ERROR("E19999", "CreateInputsForInference failed for input --graph and --inputs.");
  863. (void)ge_generator.Finalize();
  864. (void)ge::GELib::GetInstance()->Finalize();
  865. return domi::FAILED;
  866. }
  867. } else {
  868. std::map<string, string> atc_params;
  869. atc_params.insert(std::pair<string, string>("input_shape", FLAGS_input_shape));
  870. atc_params.insert(std::pair<string, string>(ge::INPUT_SHAPE_RANGE, FLAGS_input_shape_range));
  871. atc_params.insert(std::pair<string, string>("out_nodes", FLAGS_out_nodes));
  872. atc_params.insert(std::pair<string, string>("input_format", FLAGS_input_format));
  873. atc_params.insert(std::pair<string, string>("check_report", FLAGS_check_report));
  874. atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes));
  875. atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout));
  876. atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout));
  877. atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
  878. atc_params.insert(std::pair<string, string>("output", output));
  879. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  880. Status ret =
  881. ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework,
  882. FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input);
  883. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  884. // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph
  885. if (FLAGS_mode == ge::ONLY_PRE_CHECK) {
  886. (void)ge_generator.Finalize();
  887. (void)ge::GELib::GetInstance()->Finalize();
  888. if (ret != ge::SUCCESS) {
  889. DOMI_LOGE("ATC precheck fail.");
  890. return domi::FAILED;
  891. }
  892. return domi::SUCCESS;
  893. }
  894. if (ret != ge::SUCCESS) {
  895. DOMI_LOGE("ATC Parse graph domi::FAILED");
  896. DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case
  897. (void)ge_generator.Finalize();
  898. (void)ge::GELib::GetInstance()->Finalize();
  899. return domi::FAILED;
  900. }
  901. if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) {
  902. DOMI_LOGE("Set output node info fail.");
  903. (void)ge_generator.Finalize();
  904. (void)ge::GELib::GetInstance()->Finalize();
  905. return domi::FAILED;
  906. }
  907. }
  908. if (SetAttrOptions(graph) != ge::SUCCESS) {
  909. (void)ge_generator.Finalize();
  910. (void)ge::GELib::GetInstance()->Finalize();
  911. return domi::FAILED;
  912. }
  913. geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);
  914. if (geRet != ge::SUCCESS) {
  915. DOMI_LOGE("GE GenerateOfflineModel execute failed");
  916. DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case
  917. // checking error log)
  918. (void)ge_generator.Finalize();
  919. (void)ge::GELib::GetInstance()->Finalize();
  920. return domi::FAILED;
  921. }
  922. (void)ge_generator.Finalize();
  923. (void)ge::GELib::GetInstance()->Finalize();
  924. return ge::SUCCESS;
  925. }
  926. static void SetEnvForSingleOp(std::map<string, string> &options) {
  927. string flag_on = "1";
  928. string flag_off = "0";
  929. options.emplace(ge::GE_FE_FLAG, flag_on);
  930. options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream
  931. options.emplace(ge::RUN_FLAG, flag_off);
  932. options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off);
  933. options.emplace(ge::SINGLE_OP_FLAG, flag_on);
  934. options.emplace(ge::OP_PRECISION_MODE, FLAGS_op_precision_mode);
  935. options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode);
  936. options.emplace(ge::SOC_VERSION, FLAGS_soc_version);
  937. options.emplace(ge::CORE_TYPE, FLAGS_core_type);
  938. options.emplace(ge::AICORE_NUM, FLAGS_aicore_num);
  939. options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode);
  940. options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode);
  941. options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode);
  942. options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level));
  943. options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir);
  944. options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir);
  945. options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode);
  946. options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path);
  947. options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path);
  948. options.emplace(ge::TUNE_DEVICE_IDS, FLAGS_device_id);
  949. options.emplace(ge::MODIFY_MIXLIST, FLAGS_modify_mixlist);
  950. }
  951. domi::Status GenerateSingleOp(const std::string& json_file_path) {
  952. if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) {
  953. DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str());
  954. return domi::FAILED;
  955. }
  956. // check optypelist_for_implmode and op_select_implmode
  957. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  958. ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS,
  959. return ge::FAILED, "[Check][ImplmodeParam] fail for input optypelist_for_implmode and op_select_implmode.");
  960. if (!FLAGS_op_precision_mode.empty() && !ge::CheckInputPathValid(FLAGS_op_precision_mode, "--op_precision_mode")) {
  961. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  962. {"op_precision_mode", FLAGS_op_precision_mode.c_str(),
  963. "path is not found"});
  964. GELOGE(ge::FAILED, "[Check][op_precision_mode] %s not found", FLAGS_op_precision_mode.c_str());
  965. return ge::FAILED;
  966. }
  967. if (ge::CheckModifyMixlistParamValid(FLAGS_precision_mode, FLAGS_modify_mixlist) != ge::SUCCESS) {
  968. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  969. {"modify_mixlist", FLAGS_modify_mixlist.c_str(),
  970. ge::kModifyMixlistError});
  971. return ge::FAILED;
  972. }
  973. std::map<string, string> options;
  974. // need to be changed when ge.ini plan is done
  975. SetEnvForSingleOp(options);
  976. // print single op option map
  977. ge::PrintOptionMap(options, "single op option");
  978. auto ret = ge::GELib::Initialize(options);
  979. if (ret != ge::SUCCESS) {
  980. DOMI_LOGE("GE initialize failed!");
  981. return domi::FAILED;
  982. }
  983. ge::GeGenerator generator;
  984. ret = generator.Initialize(options, domi::GetContext());
  985. if (ret != SUCCESS) {
  986. DOMI_LOGE("GeGenerator initialize failed!");
  987. (void)ge::GELib::GetInstance()->Finalize();
  988. return domi::FAILED;
  989. }
  990. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
  991. vector<ge::SingleOpBuildParam> build_params;
  992. if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) {
  993. DOMI_LOGE("parse single op json file failed");
  994. (void)generator.Finalize();
  995. (void)ge::GELib::GetInstance()->Finalize();
  996. return domi::FAILED;
  997. }
  998. int index = 0;
  999. for (auto &param : build_params) {
  1000. string output_path;
  1001. if (!FLAGS_output.empty()) {
  1002. output_path = FLAGS_output + "/";
  1003. }
  1004. output_path += param.file_name;
  1005. ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path, param.compile_flag);
  1006. if (ret != SUCCESS) {
  1007. DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index);
  1008. ret = domi::FAILED;
  1009. } else {
  1010. GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str());
  1011. }
  1012. index += 1;
  1013. }
  1014. (void)generator.Finalize();
  1015. (void)ge::GELib::GetInstance()->Finalize();
  1016. return ret;
  1017. }
  1018. domi::Status GenerateOmModel() {
  1019. if (!CheckInputFormat()) {
  1020. GELOGE(ge::FAILED, "[Check][InputFormat]failed.");
  1021. return domi::FAILED;
  1022. }
  1023. Status ret = GFlagUtils::CheckFlags();
  1024. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED,
  1025. "[Check][Flags] failed! Please check whether some atc params that include semicolons[;] use double "
  1026. "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size");
  1027. #if !defined(__ANDROID__) && !defined(ANDROID)
  1028. // Load custom operator Library
  1029. LoadCustomOpLib(true);
  1030. SaveCustomCaffeProtoPath();
  1031. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "[Check][Flags]check custom aicpu run so failed!");
  1032. #endif
  1033. const int f_stream_num = 1;
  1034. std::map<string, string> options;
  1035. options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework)));
  1036. options.insert(std::pair<string, string>(string(ge::STREAM_NUM), to_string(f_stream_num)));
  1037. options.insert(std::pair<string, string>(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf));
  1038. options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes));
  1039. options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf));
  1040. options.insert(std::pair<string, string>(string(ge::OP_PRECISION_MODE), FLAGS_op_precision_mode));
  1041. options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode));
  1042. options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id));
  1043. options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
  1044. options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));
  1045. if (!FLAGS_output_type.empty()) {
  1046. options.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
  1047. }
  1048. options.insert(std::pair<string, string>(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode));
  1049. options.insert(std::pair<string, string>(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode));
  1050. if (!FLAGS_input_fp16_nodes.empty()) {
  1051. GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str());
  1052. options.insert(std::pair<string, string>(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes));
  1053. }
  1054. options.insert(std::pair<string, string>(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode));
  1055. options.insert(
  1056. std::pair<string, string>(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory)));
  1057. options.insert(std::pair<string, string>(string(ge::SOC_VERSION), FLAGS_soc_version));
  1058. options.insert(std::pair<string, string>(string(ge::CORE_TYPE), FLAGS_core_type));
  1059. options.insert(std::pair<string, string>(string(ge::AICORE_NUM), FLAGS_aicore_num));
  1060. options.insert(std::pair<string, string>(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize));
  1061. options.insert(std::pair<string, string>(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel));
  1062. options.insert(std::pair<string, string>(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file));
  1063. options.insert(std::pair<string, string>(string(ge::ENABLE_COMPRESS_WEIGHT),
  1064. (FLAGS_enable_compress_weight == "true") ?
  1065. ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse));
  1066. options.insert(std::pair<string, string>(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream));
  1067. options.insert(std::pair<string, string>(string(ge::DEBUG_DIR), FLAGS_debug_dir));
  1068. options.insert(std::pair<string, string>(string(ge::OP_COMPILER_CACHE_DIR), FLAGS_op_compiler_cache_dir));
  1069. options.insert(std::pair<string, string>(string(ge::OP_COMPILER_CACHE_MODE), FLAGS_op_compiler_cache_mode));
  1070. SetDynamicInputSizeOptions();
  1071. if (!FLAGS_save_original_model.empty()) {
  1072. options.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model));
  1073. options.insert(std::pair<string, string>(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om"));
  1074. }
  1075. options.insert(std::pair<string, string>(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level)));
  1076. options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path));
  1077. options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path));
  1078. options.insert(std::pair<string, string>(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info));
  1079. options.insert(std::pair<string, string>(string(ge::MODIFY_MIXLIST), FLAGS_modify_mixlist));
  1080. // set enable scope fusion passes
  1081. SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes);
  1082. // print atc option map
  1083. ge::PrintOptionMap(options, "atc option");
  1084. // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name
  1085. FLAGS_output = FLAGS_output + ".om";
  1086. ret = GenerateModel(options, FLAGS_output);
  1087. if (ret != domi::SUCCESS) {
  1088. return domi::FAILED;
  1089. }
  1090. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1091. if (FLAGS_display_model_info == "1") {
  1092. GELOGI("need to display model info.");
  1093. return ge::ConvertOm(FLAGS_output.c_str(), "", false);
  1094. }
  1095. return domi::SUCCESS;
  1096. }
  1097. domi::Status ConvertModelToJson() {
  1098. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1099. Status ret = GFlagUtils::CheckConverJsonParamFlags();
  1100. GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "[CheckConver][JsonParamFlags] failed!");
  1101. ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json);
  1102. GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED);
  1103. return domi::SUCCESS;
  1104. }
  1105. domi::Status DisplayModelInfo() {
  1106. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1107. // No model path passed in
  1108. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
  1109. ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
  1110. return ge::FAILED,
  1111. "[Check][Parameter]Input parameter[--om]'s value is empty!!");
  1112. // Check if the model path is valid
  1113. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  1114. FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
  1115. return ge::FAILED,
  1116. "[Check][InputPath]model file path is invalid: %s.", FLAGS_om.c_str());
  1117. if (FLAGS_framework == -1) {
  1118. return ge::ConvertOm(FLAGS_om.c_str(), "", false);
  1119. }
  1120. return ge::FAILED;
  1121. }
  1122. bool CheckRet(domi::Status ret) {
  1123. if (ret != domi::SUCCESS) {
  1124. if (FLAGS_mode == ONLY_PRE_CHECK) {
  1125. GELOGW("ATC precheck failed.");
  1126. } else if (FLAGS_mode == GEN_OM_MODEL) {
  1127. GELOGW("ATC generate offline model failed.");
  1128. } else if (FLAGS_mode == MODEL_TO_JSON) {
  1129. GELOGW("ATC convert model to json file failed.");
  1130. } else if (FLAGS_mode == PBTXT_TO_JSON) {
  1131. GELOGW("ATC convert pbtxt to json file failed.");
  1132. } else {
  1133. return false;
  1134. }
  1135. return false;
  1136. }
  1137. if (FLAGS_mode == ONLY_PRE_CHECK) {
  1138. GELOGI("ATC precheck success.");
  1139. } else if (FLAGS_mode == GEN_OM_MODEL) {
  1140. GELOGI("ATC generate offline model success.");
  1141. } else if (FLAGS_mode == MODEL_TO_JSON) {
  1142. GELOGI("ATC convert model to json file success.");
  1143. } else if (FLAGS_mode == PBTXT_TO_JSON) {
  1144. GELOGI("ATC convert pbtxt to json file success.");
  1145. }
  1146. return true;
  1147. }
  1148. domi::Status ConvertPbtxtToJson() {
  1149. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1150. Status ret = GFlagUtils::CheckConverJsonParamFlags();
  1151. if (ret != domi::SUCCESS) {
  1152. GELOGE(ge::FAILED, "[CheckConver][JsonParamFlags] failed!");
  1153. return domi::FAILED;
  1154. }
  1155. ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str());
  1156. if (ret != domi::SUCCESS) {
  1157. GELOGE(ge::FAILED, "[Convert][PbtxtToJson] fail.");
  1158. REPORT_CALL_ERROR("E19999", "ConvertPbtxtToJson failed, FLAGS_om:%s, FLAGS_json:%s.",
  1159. FLAGS_om.c_str(), FLAGS_json.c_str());
  1160. return domi::FAILED;
  1161. }
  1162. return domi::SUCCESS;
  1163. }
  1164. int init(int argc, char* argv[]) {
  1165. GFlagUtils::InitGFlag(argc, argv);
  1166. const char *gflag_argv = gflags::GetArgv();
  1167. string cmdline = gflag_argv == nullptr ? "" : gflag_argv;
  1168. domi::GetContext().atc_cmdline = cmdline;
  1169. // set log level
  1170. int ret = -1;
  1171. const std::set<string> log_level = {"null", "debug", "info", "warning", "error"};
  1172. if (log_level.count(FLAGS_log) == 0) {
  1173. std::cout << "E10010: invalid value for --log:" << FLAGS_log
  1174. <<", only support debug, info, warning, error, null"<< std::endl;
  1175. return ret;
  1176. }
  1177. ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log);
  1178. if (ret != 0) {
  1179. return ret;
  1180. }
  1181. std::string path_base = ge::GELib::GetPath();
  1182. ret = ErrorManager::GetInstance().Init(path_base);
  1183. if (ret != 0) {
  1184. DOMI_LOGE("ErrorManager init fail !");
  1185. return ret;
  1186. }
  1187. ErrorManager::GetInstance().GenWorkStreamIdDefault();
  1188. return 0;
  1189. }
  1190. long GetMemInfo(const std::string &key) {
  1191. std::string file_path = "/proc/meminfo";
  1192. std::ifstream fs(file_path, std::ifstream::in);
  1193. if (!fs.is_open()) {
  1194. GELOGW("Can not open %s .", file_path.c_str());
  1195. return 0;
  1196. }
  1197. std::string line;
  1198. while (getline(fs, line)) { // line not with \n
  1199. if (line.find(key) != std::string::npos) {
  1200. GELOGI("Find mem [%s] info line [%s]", key.c_str(), line.c_str());
  1201. fs.close();
  1202. size_t pos = line.find(":");
  1203. if (pos == std::string::npos) {
  1204. return 0;
  1205. }
  1206. std::string current_mem_info_str = line.substr(pos + 1);
  1207. ge::StringUtils::Trim(current_mem_info_str);
  1208. GELOGI("Find mem [%s] info [%s].", key.c_str(), current_mem_info_str.c_str());
  1209. return stol(current_mem_info_str);
  1210. }
  1211. }
  1212. fs.close(); // close the file
  1213. return 0;
  1214. }
  1215. bool CheckMemInfo() {
  1216. if (FLAGS_auto_tune_mode.empty()) {
  1217. return true;
  1218. }
  1219. // only check current available mem when auto_tune_mode is set.
  1220. long current_mem_available = GetMemInfo("MemAvailable");
  1221. GELOGI("Get mem available [%lu kB].", current_mem_available);
  1222. std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl;
  1223. if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) {
  1224. GELOGE(ge::PARAM_INVALID, "[Check][MemSize]Current available mem [%lu kB] can not be smaller than [%lu kB] .",
  1225. current_mem_available, kMinAvailableMem);
  1226. ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"},
  1227. {to_string(current_mem_available), to_string(kMinAvailableMem)});
  1228. return false;
  1229. }
  1230. return true;
  1231. }
  1232. int main(int argc, char* argv[]) {
  1233. ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOther);
  1234. Status ret = domi::SUCCESS;
  1235. std::cout << "ATC start working now, please wait for a moment." << std::endl;
  1236. // Initialize
  1237. if (init(argc, argv) != 0) {
  1238. std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl;
  1239. return -1;
  1240. }
  1241. do {
  1242. if (!CheckMemInfo()) {
  1243. GELOGE(ge::PARAM_INVALID, "[Check][MemInfo]Current available mem is too small.");
  1244. ret = domi::FAILED;
  1245. break;
  1246. }
  1247. if (!FLAGS_singleop.empty()) {
  1248. ret = GenerateSingleOp(FLAGS_singleop);
  1249. break;
  1250. }
  1251. // default mode(mode:0), Open source model to model
  1252. if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) {
  1253. GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break);
  1254. } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON
  1255. GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED;
  1256. break, "[Convert][ModelToJson]ATC ConvertJson execute failed!!");
  1257. } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) {
  1258. GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED;
  1259. break, "[Convert][PbtxtToJson]ATC convert pbtxt to json execute failed!!");
  1260. } else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) {
  1261. GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED;
  1262. break, "[Display][ModelInfo]ATC DisplayModelInfo failed!!");
  1263. } else {
  1264. ErrorManager::GetInstance().ATCReportErrMessage(
  1265. "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport});
  1266. GELOGE(ge::PARAM_INVALID, "[Check][Parameter]Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport);
  1267. ret = domi::FAILED;
  1268. break;
  1269. }
  1270. } while (0);
  1271. ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
  1272. if (!CheckRet(ret)) {
  1273. std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl;
  1274. int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO);
  1275. if (result != 0) {
  1276. DOMI_LOGE("ErrorManager outputErrMessage fail !");
  1277. }
  1278. GELOGI("Current mem available mem is [%lu kB]", GetMemInfo("MemAvailable"));
  1279. return ret;
  1280. } else {
  1281. std::cout << "ATC run success, welcome to the next use." << std::endl;
  1282. (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO);
  1283. return 0;
  1284. }
  1285. } /*lint +e530*/

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。