You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

single_op_parser.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "single_op_parser.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include <fstream>
  20. #include <sstream>
  21. #include <nlohmann/json.hpp>
  22. #include "framework/common/debug/ge_log.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "common/ge_inner_error_codes.h"
  25. #include "framework/common/util.h"
  26. #include "graph/utils/tensor_utils.h"
  27. using Json = nlohmann::json;
  28. using std::map;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  32. namespace {
  33. constexpr char const *kKeyOp = "op";
  34. constexpr char const *kKeyInputDesc = "input_desc";
  35. constexpr char const *kKeyOutputDesc = "output_desc";
  36. constexpr char const *kKeyAttr = "attr";
  37. constexpr char const *kKeyName = "name";
  38. constexpr char const *kKeyType = "type";
  39. constexpr char const *kKeyShape = "shape";
  40. constexpr char const *kKeyValue = "value";
  41. constexpr char const *kKeyFormat = "format";
  42. constexpr char const *kFileSuffix = ".om";
  43. constexpr int kDumpJsonIndent = 2;
  44. map<string, GeAttrValue::ValueType> kAttrTypeDict = {
  45. {"bool", GeAttrValue::VT_BOOL},
  46. {"int", GeAttrValue::VT_INT},
  47. {"float", GeAttrValue::VT_FLOAT},
  48. {"string", GeAttrValue::VT_STRING},
  49. {"list_bool", GeAttrValue::VT_LIST_BOOL},
  50. {"list_int", GeAttrValue::VT_LIST_INT},
  51. {"list_float", GeAttrValue::VT_LIST_FLOAT},
  52. {"list_string", GeAttrValue::VT_LIST_STRING},
  53. {"list_list_int", GeAttrValue::VT_LIST_LIST_INT},
  54. };
  55. map<string, DataType> kDataTypeDict = {
  56. {"bool", DT_BOOL}, {"int8", DT_INT8}, {"uint8", DT_UINT8}, {"int16", DT_INT16}, {"uint16", DT_UINT16},
  57. {"int32", DT_INT32}, {"uint32", DT_UINT32}, {"int64", DT_INT64}, {"uint64", DT_UINT64}, {"float16", DT_FLOAT16},
  58. {"half", DT_FLOAT16}, {"fp16", DT_FLOAT16}, {"float", DT_FLOAT}, {"float32", DT_FLOAT}, {"double", DT_DOUBLE},
  59. };
  60. map<string, Format> kFormatDict = {
  61. {"nchw", FORMAT_NCHW}, {"nhwc", FORMAT_NHWC}, {"nd", FORMAT_ND}, {"fractal_nz", FORMAT_FRACTAL_NZ},
  62. {"fractal_z", FORMAT_FRACTAL_Z}, {"nc1hwc0", FORMAT_NC1HWC0},
  63. };
  64. } // namespace
  65. template <typename T>
  66. void SetAttrValue(const Json &j, SingleOpAttr &attr) {
  67. attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
  68. }
  // Case-insensitive lookup in a string-keyed table.
  //
  // key is lowercased IN PLACE, so callers observe the lowercased value afterwards. It is then looked
  // up in dict, which is therefore expected to hold lowercase keys. Returns the mapped value, or
  // default_val when no entry matches (including for an empty key).
  template <typename T>
  T GetValue(const std::map<std::string, T> &dict, std::string &key, T default_val) {
    // ::tolower is undefined for negative arguments other than EOF, and a plain (possibly signed)
    // char holding a non-ASCII byte is negative. Route every byte through unsigned char first.
    std::transform(key.begin(), key.end(), key.begin(),
                   [](char c) { return static_cast<char>(::tolower(static_cast<unsigned char>(c))); });
    const auto it = dict.find(key);
    return it == dict.end() ? default_val : it->second;
  }
  78. void from_json(const Json &j, SingleOpTensorDesc &desc) {
  79. desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
  80. string format_str = j.at(kKeyFormat).get<string>();
  81. string type_str = j.at(kKeyType).get<string>();
  82. desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
  83. desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
  84. auto tensor_name = j.find(kKeyName);
  85. if (tensor_name != j.end()) {
  86. desc.name = tensor_name->get<string>();
  87. }
  88. }
  89. void from_json(const Json &j, SingleOpAttr &attr) {
  90. attr.name = j.at(kKeyName).get<string>();
  91. attr.type = j.at(kKeyType).get<string>();
  92. auto it = kAttrTypeDict.find(attr.type);
  93. if (it == kAttrTypeDict.end()) {
  94. GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str());
  95. return;
  96. }
  97. switch (it->second) {
  98. case GeAttrValue::VT_BOOL:
  99. SetAttrValue<bool>(j, attr);
  100. break;
  101. case GeAttrValue::VT_INT:
  102. SetAttrValue<int64_t>(j, attr);
  103. break;
  104. case GeAttrValue::VT_FLOAT:
  105. SetAttrValue<float>(j, attr);
  106. break;
  107. case GeAttrValue::VT_STRING:
  108. SetAttrValue<string>(j, attr);
  109. break;
  110. case GeAttrValue::VT_LIST_BOOL:
  111. SetAttrValue<vector<bool>>(j, attr);
  112. break;
  113. case GeAttrValue::VT_LIST_INT:
  114. SetAttrValue<vector<int64_t>>(j, attr);
  115. break;
  116. case GeAttrValue::VT_LIST_FLOAT:
  117. SetAttrValue<vector<float>>(j, attr);
  118. break;
  119. case GeAttrValue::VT_LIST_STRING:
  120. SetAttrValue<vector<string>>(j, attr);
  121. break;
  122. case GeAttrValue::VT_LIST_LIST_INT:
  123. SetAttrValue<vector<vector<int64_t>>>(j, attr);
  124. break;
  125. default:
  126. GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str());
  127. break;
  128. }
  129. }
  130. void from_json(const Json &j, SingleOpDesc &desc) {
  131. desc.op = j.at(kKeyOp).get<string>();
  132. auto input_desc = j.find(kKeyInputDesc);
  133. if (input_desc != j.end()) {
  134. desc.input_desc = input_desc->get<vector<SingleOpTensorDesc>>();
  135. }
  136. auto output_desc = j.find(kKeyOutputDesc);
  137. if (output_desc != j.end()) {
  138. desc.output_desc = output_desc->get<vector<SingleOpTensorDesc>>();
  139. }
  140. auto attr_field = j.find(kKeyAttr);
  141. if (attr_field != j.end()) {
  142. desc.attrs = attr_field->get<vector<SingleOpAttr>>();
  143. }
  144. }
  145. Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) {
  146. std::string real_path = RealPath(file.c_str());
  147. if (real_path.empty()) {
  148. ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file});
  149. GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str());
  150. return INTERNAL_ERROR;
  151. }
  152. std::ifstream ifs(real_path);
  153. if (!ifs.is_open()) {
  154. ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file});
  155. GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str());
  156. return FAILED;
  157. }
  158. try {
  159. ifs >> json_obj;
  160. } catch (const std::exception &e) {
  161. ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()});
  162. GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.",
  163. real_path.c_str(), e.what());
  164. return PARAM_INVALID;
  165. }
  166. ifs.close();
  167. return SUCCESS;
  168. }
  169. bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
  170. if (op_desc.op.empty()) {
  171. ErrorManager::GetInstance().ATCReportErrMessage("E10026");
  172. GELOGE(PARAM_INVALID, "Op name is empty");
  173. return false;
  174. }
  175. int index = 0;
  176. for (auto &tensor_desc : op_desc.input_desc) {
  177. if (tensor_desc.type == DT_UNDEFINED) {
  178. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)});
  179. GELOGE(false, "Input index[%d]'s dataType is invalid", index);
  180. return false;
  181. }
  182. if (tensor_desc.format == FORMAT_RESERVED) {
  183. ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)});
  184. GELOGE(PARAM_INVALID, "Input index[%d]'s format is invalid", index);
  185. return false;
  186. }
  187. ++index;
  188. }
  189. index = 0;
  190. for (auto &tensor_desc : op_desc.output_desc) {
  191. if (tensor_desc.type == DT_UNDEFINED) {
  192. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)});
  193. GELOGE(PARAM_INVALID, "Output[%d] dataType is invalid", index);
  194. return false;
  195. }
  196. if (tensor_desc.format == FORMAT_RESERVED) {
  197. ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)});
  198. GELOGE(PARAM_INVALID, "Output[%d] format is invalid", index);
  199. return false;
  200. }
  201. ++index;
  202. }
  203. for (auto &attr : op_desc.attrs) {
  204. if (attr.name.empty()) {
  205. ErrorManager::GetInstance().ATCReportErrMessage("E10029");
  206. GELOGE(PARAM_INVALID, "attr name is empty");
  207. return false;
  208. }
  209. if (attr.value.IsEmpty()) {
  210. ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name});
  211. GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str());
  212. return false;
  213. }
  214. }
  215. return true;
  216. }
  217. OpDesc *SingleOpParser::CreateOpDesc(const string &op_type) { return new (std::nothrow) OpDesc(op_type, op_type); }
  218. Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc,
  219. SingleOpBuildParam &build_param) {
  220. auto *op_desc = CreateOpDesc(single_op_desc.op);
  221. if (op_desc == nullptr) {
  222. GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc");
  223. return MEMALLOC_FAILED;
  224. }
  225. std::stringstream file_name;
  226. file_name << index;
  227. file_name << "_" << single_op_desc.op;
  228. for (auto &desc : single_op_desc.input_desc) {
  229. file_name << "_" << desc.type << "_" << desc.format;
  230. for (auto dim : desc.dims) {
  231. file_name << "_" << dim;
  232. }
  233. GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type);
  234. ge_tensor_desc.SetOriginFormat(desc.format);
  235. TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size());
  236. TensorUtils::SetInputTensor(ge_tensor_desc, true);
  237. TensorUtils::SetOutputTensor(ge_tensor_desc, false);
  238. if (desc.name.empty()) {
  239. op_desc->AddInputDesc(ge_tensor_desc);
  240. } else {
  241. op_desc->AddInputDesc(desc.name, ge_tensor_desc);
  242. }
  243. build_param.inputs.emplace_back(ge_tensor_desc);
  244. }
  245. for (auto &desc : single_op_desc.output_desc) {
  246. file_name << "_" << desc.type << "_" << desc.format;
  247. for (auto dim : desc.dims) {
  248. file_name << "_" << dim;
  249. }
  250. GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type);
  251. ge_tensor_desc.SetOriginFormat(desc.format);
  252. TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size());
  253. TensorUtils::SetInputTensor(ge_tensor_desc, false);
  254. TensorUtils::SetOutputTensor(ge_tensor_desc, true);
  255. op_desc->AddOutputDesc(ge_tensor_desc);
  256. build_param.outputs.emplace_back(ge_tensor_desc);
  257. }
  258. for (const auto &attr : single_op_desc.attrs) {
  259. op_desc->SetAttr(attr.name, attr.value);
  260. }
  261. file_name << kFileSuffix;
  262. build_param.file_name = file_name.str();
  263. build_param.op_desc.reset(op_desc);
  264. return SUCCESS;
  265. }
  266. Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list) {
  267. Json single_op_list_json;
  268. auto ret = ReadJsonFile(file, single_op_list_json);
  269. if (ret != SUCCESS) {
  270. return ret;
  271. }
  272. int index = 0;
  273. for (const Json &single_op_json : single_op_list_json) {
  274. GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
  275. SingleOpDesc single_op_desc;
  276. try {
  277. single_op_desc = single_op_json;
  278. } catch (const nlohmann::json::exception &e) {
  279. ErrorManager::GetInstance().ATCReportErrMessage(
  280. "E10045", {"index", "jsonfile", "exception", "jsonStr"},
  281. {std::to_string(index), file, e.what(), single_op_json.dump(kDumpJsonIndent)});
  282. GELOGE(PARAM_INVALID, "Parse op[%d] failed when read json file[%s], exception[%s], jsonStr[%s]", index,
  283. file.c_str(), e.what(), single_op_json.dump(kDumpJsonIndent).c_str());
  284. return PARAM_INVALID;
  285. }
  286. if (!Validate(single_op_desc)) {
  287. ErrorManager::GetInstance().ATCReportErrMessage("E10046", {"index", "jsonfile"}, {std::to_string(index), file});
  288. GELOGE(PARAM_INVALID, "Validate op[%d] failed when read json file[%s].", index, file.c_str());
  289. return PARAM_INVALID;
  290. }
  291. SingleOpBuildParam param;
  292. ret = ConvertToBuildParam(index, single_op_desc, param);
  293. if (ret != SUCCESS) {
  294. return ret;
  295. }
  296. op_list.emplace_back(param);
  297. GELOGI("Parse op[%d] success", index);
  298. index += 1;
  299. }
  300. return SUCCESS;
  301. }
  302. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。