You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

acl_graph_parser_util.cc 37 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  13. #include "parser/common/acl_graph_parser_util.h"
  14. #include <dlfcn.h>
  15. #include <regex.h>
  16. #include <cstdlib>
  17. #include <ctime>
  18. #include <fstream>
  19. #include "common/debug/log.h"
  20. #include "common/op/ge_op_utils.h"
  21. #include "common/string_util.h"
  22. #include "common/types.h"
  23. #include "common/util.h"
  24. #include "common/util/error_manager/error_manager.h"
  25. #include "external/ge/ge_api_types.h"
  26. #include "framework/common/debug/ge_log.h"
  27. #include "framework/omg/parser/parser_types.h"
  28. #include "ge/ge_api_types.h"
  29. #include "google/protobuf/io/coded_stream.h"
  30. #include "google/protobuf/io/zero_copy_stream_impl.h"
  31. #include "graph/opsproto_manager.h"
  32. #include "graph/utils/type_utils.h"
  33. #include "omg/parser/parser_inner_ctx.h"
  34. #include "parser/common/register_tbe.h"
  35. #include "tbe_plugin_loader.h"
  36. using google::protobuf::io::CodedInputStream;
  37. using google::protobuf::io::FileInputStream;
  38. using google::protobuf::io::ZeroCopyInputStream;
  39. using namespace ge::parser;
  40. namespace {
  41. const std::string kGraphDefaultName = "domi_default";
  42. /// The maximum length of the file.
  43. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
  44. const int kMaxFileSizeLimit = INT_MAX;
  45. const int kMaxBuffSize = 256;
  46. const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
  47. const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
  48. static string GetSoPath() {
  49. Dl_info dl_info;
  50. if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) {
  51. GELOGW("Failed to read so_path!");
  52. return string();
  53. } else {
  54. std::string so_path = dl_info.dli_fname;
  55. char path[PATH_MAX] = {0};
  56. if (so_path.length() >= PATH_MAX) {
  57. GELOGW("File path is too long!");
  58. return string();
  59. }
  60. if (realpath(so_path.c_str(), path) == nullptr) {
  61. GELOGW("Failed to get realpath of %s", so_path.c_str());
  62. return string();
  63. }
  64. so_path = path;
  65. so_path = so_path.substr(0, so_path.rfind('/') + 1);
  66. return so_path;
  67. }
  68. }
  69. static void GetOpsProtoPath(string &opsproto_path) {
  70. GELOGD("Start to get ops proto path schedule.");
  71. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  72. if (path_env != nullptr) {
  73. string path = path_env;
  74. string file_path = ge::parser::RealPath(path.c_str());
  75. if (file_path.empty()) {
  76. GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str());
  77. return;
  78. }
  79. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  80. GELOGI("Get opsproto so path from env : %s", path.c_str());
  81. return;
  82. }
  83. string path_base = GetSoPath();
  84. GELOGI("path_base is %s", path_base.c_str());
  85. path_base = path_base.substr(0, path_base.rfind('/'));
  86. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  87. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  88. }
  89. static void GetAclParams(const std::map<ge::AscendString, ge::AscendString> &parser_params, const string &key,
  90. string &value) {
  91. for (auto &ele : parser_params) {
  92. const char *key_ascend = ele.first.GetString();
  93. if (key_ascend == nullptr) {
  94. GELOGW("Input options key is null, Please check!");
  95. continue;
  96. }
  97. string key_str = key_ascend;
  98. if (key == key_str) {
  99. const char *value_ascend = ele.second.GetString();
  100. if (value_ascend == nullptr) {
  101. value = "";
  102. } else {
  103. value = value_ascend;
  104. }
  105. return;
  106. }
  107. }
  108. value = "";
  109. return;
  110. }
  111. static bool CheckDigitStr(std::string &str) {
  112. for (char c : str) {
  113. if (!isdigit(c)) {
  114. GELOGE(domi::FAILED, "Value[%s] is not positive integer", str.c_str());
  115. return false;
  116. }
  117. }
  118. return true;
  119. }
  120. } // namespace
  121. namespace ge {
  122. static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) {
  123. if ((s == "true") || (s == "false")) {
  124. return true;
  125. } else {
  126. ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s});
  127. GELOGE(PARAM_INVALID, "Input parameter[%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
  128. return false;
  129. }
  130. }
  131. static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
  132. int32_t out_size = op_desc->GetOutputsSize();
  133. if (index < 0 || index >= out_size) {
  134. GELOGE(domi::FAILED,
  135. "out_node [%s] output index:%d must be smaller "
  136. "than node output size:%d and can not be negative!",
  137. op_desc->GetName().c_str(), index, out_size);
  138. std::string fail_reason = "output index:" + to_string(index) +
  139. " must be smaller than output size:" + to_string(out_size) + " and can not be negative!";
  140. ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"},
  141. {"out_nodes", op_desc->GetName(), fail_reason});
  142. return domi::FAILED;
  143. }
  144. return domi::SUCCESS;
  145. }
  146. domi::Status AclGrphParseUtil::LoadOpsProtoLib() {
  147. string opsproto_path;
  148. GetOpsProtoPath(opsproto_path);
  149. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  150. OpsProtoManager *manager = OpsProtoManager::Instance();
  151. map<string, string> option_tmp;
  152. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  153. bool is_proto_init = manager->Initialize(option_tmp);
  154. if (!is_proto_init) {
  155. GELOGE(FAILED, "Load ops_proto lib failed, ops proto path is invalid.");
  156. return FAILED;
  157. }
  158. return SUCCESS;
  159. }
  160. void AclGrphParseUtil::SaveCustomCaffeProtoPath() {
  161. GELOGD("Enter save custom caffe proto path.");
  162. std::string path_base = GetSoPath();
  163. path_base = path_base.substr(0, path_base.rfind('/'));
  164. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  165. ge::GetParserContext().caffe_proto_path = path_base + "include/proto/";
  166. string custom_op_path;
  167. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  168. if (path_env != nullptr) {
  169. std::string path = path_env;
  170. custom_op_path = path + "/framework/custom/caffe/";
  171. GELOGI("Get custom proto path from env : %s", path_env);
  172. GetParserContext().custom_proto_path = custom_op_path;
  173. return;
  174. }
  175. custom_op_path = path_base + "ops/framework/custom/caffe/";
  176. ge::GetParserContext().custom_proto_path = custom_op_path;
  177. return;
  178. }
  179. // Initialize PARSER, load custom op plugin
  180. // options will be used later for parser decoupling
  181. domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) {
  182. GELOGT(TRACE_INIT, "AclParserInitialize start");
  183. // check init status
  184. if (parser_initialized) {
  185. GELOGW("AclParserInitialize is called more than once");
  186. return SUCCESS;
  187. }
  188. // load custom op plugin
  189. TBEPluginLoader::Instance().LoadPluginSo(options);
  190. // load and save custom op proto for prediction
  191. (void)LoadOpsProtoLib();
  192. SaveCustomCaffeProtoPath();
  193. auto op_registry = domi::OpRegistry::Instance();
  194. if (op_registry == nullptr) {
  195. GELOGE(FAILED, "Get OpRegistry instance failed");
  196. return FAILED;
  197. }
  198. auto it = options.find(ge::FRAMEWORK_TYPE);
  199. if (it == options.end()) {
  200. GELOGE(FAILED, "Can not find ge.frameworkType in options");
  201. return FAILED;
  202. }
  203. std::string fmk_type = it->second;
  204. std::vector<OpRegistrationData> registrationDatas = op_registry->registrationDatas;
  205. GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size());
  206. for (OpRegistrationData &reg_data : registrationDatas) {
  207. if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) {
  208. (void)OpRegistrationTbe::Instance()->Finalize(reg_data, false);
  209. (void)domi::OpRegistry::Instance()->Register(reg_data);
  210. }
  211. }
  212. // set init status
  213. if (!parser_initialized) {
  214. // Initialize success, first time calling initialize
  215. parser_initialized = true;
  216. }
  217. GELOGT(TRACE_STOP, "AclParserInitialize finished");
  218. return SUCCESS;
  219. }
  220. void AclGrphParseUtil::SetDefaultFormat() {
  221. if (ge::GetParserContext().type == domi::TENSORFLOW) {
  222. ge::GetParserContext().format = domi::DOMI_TENSOR_NHWC;
  223. } else {
  224. ge::GetParserContext().format = domi::DOMI_TENSOR_NCHW;
  225. }
  226. }
  227. domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
  228. try {
  229. // parse output node
  230. if (!out_nodes.empty()) {
  231. ge::GetParserContext().out_nodes_map.clear();
  232. ge::GetParserContext().user_out_nodes.clear();
  233. ge::GetParserContext().user_out_nodes_top_vec.clear();
  234. vector<string> nodes_v = StringUtils::Split(out_nodes, ';');
  235. for (const string &node : nodes_v) {
  236. vector<string> key_value_v = StringUtils::Split(node, ':');
  237. if (key_value_v.size() != 2) { // The size must be 2.
  238. if (key_value_v.size() == 1 && ge::GetParserContext().type == domi::CAFFE) {
  239. ge::GetParserContext().user_out_nodes_top_vec.push_back(node);
  240. continue;
  241. }
  242. ErrorManager::GetInstance().ATCReportErrMessage(
  243. "E10001", {"parameter", "value", "reason"},
  244. {"out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
  245. GELOGE(PARAM_INVALID,
  246. "The input format of out_nodes is invalid, the correct format is "
  247. "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.",
  248. node.c_str());
  249. return PARAM_INVALID;
  250. }
  251. if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
  252. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  253. {"out_nodes", out_nodes, "is not all index or top_name"});
  254. GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s",
  255. out_nodes.c_str());
  256. return PARAM_INVALID;
  257. }
  258. // stoi: The method may throw an exception: invalid_argument/out_of_range
  259. if (!CheckDigitStr(key_value_v[1])) {
  260. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  261. {"out_nodes", out_nodes, "is not positive integer"});
  262. GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str());
  263. return PARAM_INVALID;
  264. }
  265. auto iter = ge::GetParserContext().out_nodes_map.find(key_value_v[0]);
  266. int32_t index = stoi(StringUtils::Trim(key_value_v[1]));
  267. GELOGD("Get output info: node[%s] and index[%d]", key_value_v[0].c_str(), index);
  268. if (iter != ge::GetParserContext().out_nodes_map.end()) {
  269. iter->second.emplace_back(index);
  270. } else {
  271. std::vector<int32_t> index_v;
  272. index_v.emplace_back(index);
  273. ge::GetParserContext().out_nodes_map.emplace(key_value_v[0], index_v);
  274. }
  275. ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index));
  276. }
  277. }
  278. } catch (std::invalid_argument &) {
  279. GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
  280. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes});
  281. return PARAM_INVALID;
  282. } catch (std::out_of_range &) {
  283. GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
  284. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes});
  285. return PARAM_INVALID;
  286. }
  287. return SUCCESS;
  288. }
  289. domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) {
  290. if (is_output_fp16.empty()) {
  291. return SUCCESS;
  292. }
  293. vector<domiTensorFormat_t> &output_formats = ge::GetParserContext().output_formats;
  294. output_formats.clear();
  295. vector<string> node_format_vec = StringUtils::Split(is_output_fp16, ',');
  296. for (auto &is_fp16 : node_format_vec) {
  297. StringUtils::Trim(is_fp16);
  298. if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) {
  299. GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]",
  300. is_output_fp16.c_str());
  301. return PARAM_INVALID;
  302. }
  303. if (is_fp16 == "false") {
  304. output_formats.push_back(DOMI_TENSOR_ND);
  305. } else if (is_fp16 == "true") {
  306. output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0);
  307. }
  308. }
  309. return SUCCESS;
  310. }
  311. domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) {
  312. ge::GetParserContext().enable_scope_fusion_passes.clear();
  313. if (enable_scope_fusion_passes.empty()) {
  314. return SUCCESS;
  315. }
  316. ge::GetParserContext().enable_scope_fusion_passes = enable_scope_fusion_passes;
  317. return SUCCESS;
  318. }
  319. void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec,
  320. const string &fp16_nodes_name, uint32_t index, OpDescPtr &op_desc) {
  321. if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) {
  322. if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) {
  323. GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str());
  324. if (!AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))) {
  325. GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str());
  326. }
  327. }
  328. }
  329. }
  330. domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
  331. const string &is_input_adjust_hw_layout) {
  332. GE_CHECK_NOTNULL(graph);
  333. vector<string> adjust_fp16_format_vec;
  334. if (!is_input_adjust_hw_layout.empty()) {
  335. adjust_fp16_format_vec = StringUtils::Split(is_input_adjust_hw_layout, ',');
  336. for (auto &s : adjust_fp16_format_vec) {
  337. StringUtils::Trim(s);
  338. if (!CheckInputTrueOrFalse(s, "is_input_adjust_hw_layout")) {
  339. GELOGE(PARAM_INVALID, "Invalid Param, is_input_adjust_hw_layout only support true/false: but is [%s]",
  340. is_input_adjust_hw_layout.c_str());
  341. return PARAM_INVALID;
  342. }
  343. }
  344. }
  345. if (input_fp16_nodes.empty()) {
  346. return SUCCESS;
  347. }
  348. GELOGI("The input_fp16_nodes is set %s", input_fp16_nodes.c_str());
  349. vector<string> input_fp16_nodes_vec = StringUtils::Split(input_fp16_nodes, ';');
  350. for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) {
  351. ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]);
  352. if (node == nullptr) {
  353. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  354. {"input_fp16_nodes", input_fp16_nodes_vec[i]});
  355. GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not exist in model",
  356. input_fp16_nodes_vec[i].c_str());
  357. return PARAM_INVALID;
  358. }
  359. auto op_desc = node->GetOpDesc();
  360. GE_CHECK_NOTNULL(op_desc);
  361. if (op_desc->GetType() != ge::parser::DATA) {
  362. ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"},
  363. {"input_fp16_nodes", input_fp16_nodes_vec[i]});
  364. GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not a input opname",
  365. input_fp16_nodes_vec[i].c_str());
  366. return PARAM_INVALID;
  367. }
  368. AddAttrsForInputNodes(adjust_fp16_format_vec, input_fp16_nodes_vec[i], i, op_desc);
  369. }
  370. return SUCCESS;
  371. }
  372. void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
  373. std::vector<std::string> &output_nodes_name) {
  374. output_nodes_name.clear();
  375. if (ge::GetParserContext().out_top_names.empty()) {
  376. // tf process, no top name.
  377. for (const auto output_node_info : output_nodes_info) {
  378. std::string node_name = output_node_info.first->GetName();
  379. int32_t index = output_node_info.second;
  380. output_nodes_name.push_back(node_name + ":" + std::to_string(index));
  381. }
  382. return;
  383. }
  384. // caffe process, need add top name after node_name:index
  385. for (size_t i = 0; i < output_nodes_info.size(); ++i) {
  386. std::string node_name = output_nodes_info[i].first->GetName();
  387. int32_t index = output_nodes_info[i].second;
  388. if (i < ge::GetParserContext().out_top_names.size()) {
  389. output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" +
  390. ge::GetParserContext().out_top_names[i]);
  391. } else {
  392. GELOGW("Get top name of node [%s] fail.", node_name.c_str());
  393. output_nodes_name.push_back(node_name + ":" + std::to_string(index));
  394. }
  395. }
  396. }
  397. domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
  398. std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
  399. ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
  400. if (tmpDescPtr == nullptr) {
  401. GELOGE(domi::FAILED, "Get outnode op desc fail.");
  402. return domi::FAILED;
  403. }
  404. size_t size = tmpDescPtr->GetOutputsSize();
  405. if (node->GetType() != ge::parser::NETOUTPUT) {
  406. for (size_t index = 0; index < size; ++index) {
  407. output_nodes_info.push_back(std::make_pair(node, index));
  408. GELOGD("Get output leaf node:%s.", node->GetName().c_str());
  409. }
  410. } else {
  411. const auto in_anchors = node->GetAllInDataAnchors();
  412. for (auto in_anchor : in_anchors) {
  413. auto out_anchor = in_anchor->GetPeerOutAnchor();
  414. if (out_anchor == nullptr) {
  415. GELOGE(domi::FAILED, "Get leaf node op desc fail.");
  416. return domi::FAILED;
  417. }
  418. auto out_node = out_anchor->GetOwnerNode();
  419. output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx()));
  420. }
  421. }
  422. return SUCCESS;
  423. }
  424. domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
  425. std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
  426. std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes;
  427. if (ge::GetParserContext().type == domi::CAFFE && !default_out_nodes.empty()) {
  428. for (uint32_t i = 0; i < default_out_nodes.size(); ++i) {
  429. ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first);
  430. if (out_node == nullptr) {
  431. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  432. {"out_nodes", default_out_nodes[i].first});
  433. GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", default_out_nodes[i].first.c_str());
  434. return domi::FAILED;
  435. }
  436. output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second));
  437. GELOGD("Get default output node:%s.", out_node->GetName().c_str());
  438. }
  439. return domi::SUCCESS;
  440. }
  441. for (ge::NodePtr node : compute_graph->GetDirectNode()) {
  442. if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) {
  443. Status ret = GetOutputLeaf(node, output_nodes_info);
  444. GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Find leaf fail.");
  445. }
  446. }
  447. return domi::SUCCESS;
  448. }
  449. domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
  450. const std::map<AscendString, AscendString> &parser_params) {
  451. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  452. GE_CHECK_NOTNULL(compute_graph);
  453. std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes;
  454. std::vector<domiTensorFormat_t> output_formats = ge::GetParserContext().output_formats;
  455. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info;
  456. std::vector<std::string> output_nodes_name;
  457. // User declared outputs
  458. for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
  459. ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first);
  460. if (out_node == nullptr) {
  461. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  462. {"out_nodes", user_out_nodes[i].first});
  463. GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str());
  464. return domi::FAILED;
  465. }
  466. auto op_desc = out_node->GetOpDesc();
  467. GE_CHECK_NOTNULL(op_desc);
  468. if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) {
  469. GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str());
  470. return domi::FAILED;
  471. }
  472. // add user_define_output_nodes attr.
  473. (void)ge::AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_OUTPUT_NODES, "true");
  474. if (i < output_formats.size()) {
  475. if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) {
  476. GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str());
  477. vector<string> output_fp16_5hd_vec;
  478. (void)ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec);
  479. output_fp16_5hd_vec.push_back(std::to_string(user_out_nodes[i].second) + ":" + "NC1HWC0");
  480. (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec);
  481. }
  482. }
  483. output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second));
  484. }
  485. // default output node (leaf)
  486. if (user_out_nodes.empty()) {
  487. if (GetDefaultOutInfo(compute_graph, output_nodes_info) != SUCCESS) {
  488. GELOGE(domi::FAILED, "Get default output info failed.");
  489. return domi::FAILED;
  490. }
  491. }
  492. GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
  493. compute_graph->SetGraphOutNodesInfo(output_nodes_info);
  494. ge::GetParserContext().net_out_nodes = output_nodes_name;
  495. GELOGI("Set graph %s output node success.", graph.GetName().c_str());
  496. return domi::SUCCESS;
  497. }
  498. domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) {
  499. for (auto &ele : parser_params) {
  500. const char *key_ascend = ele.first.GetString();
  501. if (key_ascend == nullptr) {
  502. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  503. {"parser_params", "null AscendString"});
  504. GELOGE(PARAM_INVALID, "Input options key is null, Please check!");
  505. return PARAM_INVALID;
  506. }
  507. string key_str = key_ascend;
  508. auto it = ge::ir_option::ir_parser_suppported_options.find(key_str);
  509. if (it == ge::ir_option::ir_parser_suppported_options.end()) {
  510. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"parser_params", key_str});
  511. GELOGE(PARAM_INVALID, "Input options include unsupported option(%s).Please check!", key_ascend);
  512. return PARAM_INVALID;
  513. }
  514. }
  515. return SUCCESS;
  516. }
  517. domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
  518. string &graph_name) {
  519. GELOGI("Parse graph user options start.");
  520. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID,
  521. "Parse paragrams invalid.");
  522. // support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes
  523. SetDefaultFormat();
  524. string out_nodes;
  525. GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes);
  526. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS, return PARAM_INVALID,
  527. "Parse out_nodes failed");
  528. string is_output_adjust_hw_layout;
  529. GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout);
  530. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS,
  531. return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed");
  532. string tmp_name;
  533. GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name);
  534. graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name;
  535. string enable_scope_fusion_passes;
  536. GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes);
  537. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS, return PARAM_INVALID,
  538. "Parse enable_scope_fusion_passes failed");
  539. return SUCCESS;
  540. }
  541. domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
  542. const std::map<AscendString, AscendString> &parser_params) {
  543. // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout,
  544. ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  545. GE_CHECK_NOTNULL(compute_graph);
  546. string input_fp16_nodes;
  547. GetAclParams(parser_params, ge::ir_option::INPUT_FP16_NODES, input_fp16_nodes);
  548. string is_input_adjust_hw_layout;
  549. GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout);
  550. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  551. ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS,
  552. return PARAM_INVALID, "Parse input_fp16_nodes failed");
  553. return SUCCESS;
  554. }
  555. namespace parser {
  556. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) {
  557. if (path == nullptr) {
  558. GELOGE(ge::FAILED, "path pointer is NULL.");
  559. return "";
  560. }
  561. if (strlen(path) >= PATH_MAX) {
  562. ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)});
  563. GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX);
  564. return "";
  565. }
  566. // Nullptr is returned when the path does not exist or there is no permission
  567. // Return absolute path when path is accessible
  568. std::string res;
  569. char resolved_path[PATH_MAX] = {0};
  570. if (realpath(path, resolved_path) != nullptr) {
  571. res = resolved_path;
  572. }
  573. return res;
  574. }
  575. // Get file length
  576. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) {
  577. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null.");
  578. std::string real_path = RealPath(input_file.c_str());
  579. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
  580. unsigned long long file_length = 0;
  581. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
  582. ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"},
  583. {input_file, strerror(errno)});
  584. return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));
  585. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
  586. ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
  587. return -1, "File[%s] size is 0, not valid.", input_file.c_str());
  588. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit,
  589. ErrorManager::GetInstance().ATCReportErrMessage(
  590. "E19016", {"filepath", "filesize", "maxlen"},
  591. {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
  592. return -1, "File[%s] size %lld is out of limit: %d.",
  593. input_file.c_str(), file_length, kMaxFileSizeLimit);
  594. return static_cast<long>(file_length);
  595. }
  596. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
  597. struct timeval tv{};
  598. int ret = gettimeofday(&tv, nullptr);
  599. GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret);
  600. auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds
  601. return static_cast<uint64_t>(total_use_time);
  602. }
  603. static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) {
  604. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr,
  605. return false, "incorrect parameter. nullptr == proto");
  606. coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold);
  607. return proto->ParseFromCodedStream(&coded_stream);
  608. }
  609. /** @ingroup domi_common
  610. * @brief Read all data from binary file
  611. * @param [in] file_name File path
  612. * @param [out] buffer The address of the output memory, which needs to be released by the caller
  613. * @param [out] length Output memory size
  614. * @return false fail
  615. * @return true success
  616. */
  617. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer,
  618. int &length) {
  619. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr");
  620. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr");
  621. std::string real_path = RealPath(file_name);
  622. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name);
  623. std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate);
  624. if (!file.is_open()) {
  625. GELOGE(ge::FAILED, "Read file %s failed.", file_name);
  626. return false;
  627. }
  628. length = static_cast<int>(file.tellg());
  629. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0");
  630. file.seekg(0, std::ios::beg);
  631. *buffer = new(std::nothrow) char[length]();
  632. GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed.");
  633. file.read(*buffer, length);
  634. file.close();
  635. return true;
  636. }
  637. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) {
  638. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr),
  639. return false,
  640. "Input parameter file or proto is nullptr!");
  641. std::string real_path = RealPath(file);
  642. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
  643. return false, "pb file path '%s' not valid", file);
  644. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
  645. std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
  646. if (!fs.is_open()) {
  647. ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"});
  648. GELOGE(ge::FAILED, "Open real path[%s] failed.", file);
  649. return false;
  650. }
  651. google::protobuf::io::IstreamInputStream istream(&fs);
  652. google::protobuf::io::CodedInputStream coded_stream(&istream);
  653. bool ret = ReadProtoFromCodedInputStream(coded_stream, proto);
  654. fs.close();
  655. if (!ret) {
  656. ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file});
  657. GELOGE(ge::FAILED, "Parse file[%s] failed.", file);
  658. return ret;
  659. }
  660. return ret;
  661. }
  662. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) {
  663. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false,
  664. "incorrect parameter. proto is nullptr || data is nullptr || size is 0");
  665. google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size);
  666. return ReadProtoFromCodedInputStream(coded_stream, proto);
  667. }
  668. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file,
  669. google::protobuf::Message *message) {
  670. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false,
  671. "incorrect parameter. nullptr == file || nullptr == message");
  672. std::string real_path = RealPath(file);
  673. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
  674. ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"},
  675. {file, strerror(errno)});
  676. return false, "Path[%s]'s realpath is empty, errmsg[%s]", file,
  677. strerror(errno));
  678. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
  679. std::ifstream fs(real_path.c_str(), std::ifstream::in);
  680. if (!fs.is_open()) {
  681. ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file});
  682. GELOGE(ge::FAILED,
  683. "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file);
  684. return false;
  685. }
  686. google::protobuf::io::IstreamInputStream input(&fs);
  687. bool ret = google::protobuf::TextFormat::Parse(&input, message);
  688. GE_IF_BOOL_EXEC(!ret,
  689. ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file});
  690. GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, "
  691. "please check whether the file is a valid protobuf format file.", file));
  692. fs.close();
  693. return ret;
  694. }
  695. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size,
  696. google::protobuf::Message *message) {
  697. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false,
  698. "incorrect parameter. data is nullptr || message is nullptr");
  699. std::string str(data, static_cast<size_t>(size));
  700. std::istringstream fs(str);
  701. google::protobuf::io::IstreamInputStream input(&fs);
  702. bool ret = google::protobuf::TextFormat::Parse(&input, message);
  703. GE_IF_BOOL_EXEC(
  704. !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));
  705. return ret;
  706. }
  707. ///
  708. /// @brief get the Original Type of FrameworkOp
  709. /// @param [in] node
  710. /// @param [out] type
  711. /// @return Status
  712. ///
  713. Status GetOriginalType(const ge::NodePtr &node, string &type) {
  714. GE_CHECK_NOTNULL(node);
  715. type = node->GetType();
  716. GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  717. GE_CHECK_NOTNULL(node->GetOpDesc());
  718. bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  719. if (!ret) {
  720. GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
  721. return INTERNAL_ERROR;
  722. }
  723. GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  724. return SUCCESS;
  725. }
  726. FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) {
  727. char ebuff[kMaxBuffSize];
  728. regex_t reg;
  729. int cflags = REG_EXTENDED | REG_NOSUB;
  730. int ret = regcomp(&reg, mode.c_str(), cflags);
  731. if (ret) {
  732. regerror(ret, &reg, ebuff, kMaxBuffSize);
  733. GELOGW("regcomp failed, reason: %s", ebuff);
  734. regfree(&reg);
  735. return true;
  736. }
  737. ret = regexec(&reg, str.c_str(), 0, nullptr, 0);
  738. if (ret) {
  739. regerror(ret, &reg, ebuff, kMaxBuffSize);
  740. GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff);
  741. regfree(&reg);
  742. return false;
  743. }
  744. regfree(&reg);
  745. return true;
  746. }
  747. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() {
  748. std::time_t now = std::time(nullptr);
  749. std::tm *ptm = std::localtime(&now);
  750. if (ptm == nullptr) {
  751. GELOGE(ge::FAILED, "Localtime failed.");
  752. return "";
  753. }
  754. const int kTimeBufferLen = 32;
  755. char buffer[kTimeBufferLen + 1] = {0};
  756. // format: 20171122042550
  757. std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
  758. return std::string(buffer);
  759. }
  760. } // namespace parser
  761. } // namespace ge