You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_generator.cc 50 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "framework/generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "framework/common/helper/model_helper.h"
  21. #include "framework/common/helper/om_file_helper.h"
  22. #include "framework/common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "external/ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/graph_var_manager.h"
  31. #include "graph/manager/util/rt_context_util.h"
  32. #include "graph/operator_factory_impl.h"
  33. #include "graph/opsproto_manager.h"
  34. #include "graph/utils/graph_utils.h"
  35. #include "graph/utils/type_utils.h"
  36. #include "init/gelib.h"
  37. #include "common/model/ge_model.h"
  38. #include "analyzer/analyzer.h"
  39. using std::map;
  40. using std::string;
  41. using std::vector;
  42. namespace {
  43. const char *const kAttrOpType = "op_type";
  44. const char *const kEngineNameDefault = "default";
  45. const char *const kVectorEngine = "VectorEngine";
  46. const char *const kAIcoreEngine = "AIcoreEngine";
  47. const char *const kFileNameSuffix = "online";
  48. const char *const kAicpuAllshape = "_AllShape";
  49. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  50. const int64_t kDynamicDimValue = -2;
  51. const int kDefaultDeviceId = 0;
  52. const int kDefaultJobId = 0;
  53. const int32_t kFuzzBuildPattern = 1;
  54. const uint64_t kInitialSessionId = INT32_MAX;
  55. std::map<ge::OpEngineType, std::string> engine_type_map{
  56. {ge::ENGINE_SYS, kEngineNameDefault},
  57. {ge::ENGINE_AICORE, kAIcoreEngine},
  58. {ge::ENGINE_VECTOR, kVectorEngine}};
  59. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  60. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  61. if (tensor_desc->MutableShape().IsUnknownShape()) {
  62. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  63. return true;
  64. }
  65. }
  66. return false;
  67. }
  68. // if optional in/out, format is format_reserved and dtype is dt_undefined
  69. bool IsOptional(const ge::GeTensorDesc &tensor_desc) {
  70. return tensor_desc.GetFormat() == ge::FORMAT_RESERVED && tensor_desc.GetDataType() == ge::DT_UNDEFINED;
  71. }
  72. } // namespace
  73. namespace ge {
  74. static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) {
  75. const OpDescPtr &op_desc = node->GetOpDesc();
  76. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  77. if (engine_type == ENGINE_SYS) {
  78. GELOGI("CheckEngineType: use default engine.");
  79. return SUCCESS;
  80. }
  81. // get op engine name
  82. string op_engine_name;
  83. auto iter = engine_type_map.find(engine_type);
  84. if (iter != engine_type_map.end()) {
  85. op_engine_name = iter->second;
  86. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  87. } else {
  88. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  89. {op_desc->GetName(), op_desc->GetType(), "engine type",
  90. "it only support default/AIcoreEngine/VectorEngine"});
  91. GELOGE(FAILED, "[Check][Param] value:%d not support, "
  92. "only support default/AIcoreEngine/VectorEngine now", static_cast<int>(engine_type));
  93. return FAILED;
  94. }
  95. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  96. op_desc->SetOpEngineName(op_engine_name);
  97. op_desc->SetOpKernelLibName(op_engine_name);
  98. return SUCCESS;
  99. }
  100. // set op engine name and opkernelLib. when engine support
  101. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  102. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  103. REPORT_INNER_ERROR("E19999", "get gelib failed, as get instance failed or initflag failed.");
  104. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] CheckEngineType failed, as get gelib failed.");
  105. return FAILED;
  106. }
  107. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  108. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  109. if (op_infos.empty()) {
  110. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  111. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  112. GELOGE(FAILED, "[Get][OpInfo] by op type %s failed.", op_desc->GetType().c_str());
  113. return FAILED;
  114. }
  115. string kernel_name;
  116. for (const auto &it : op_infos) {
  117. if (it.engine == op_engine_name) {
  118. kernel_name = it.opKernelLib;
  119. break;
  120. }
  121. }
  122. if (kernel_name.empty()) {
  123. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  124. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  125. GELOGE(FAILED, "[Check][Param] Can not find ops kernel, engine name:%s. op:%s(%s)",
  126. op_engine_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str());
  127. return FAILED;
  128. }
  129. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  130. auto kernel_info_store = kernel_map.find(kernel_name);
  131. if (kernel_info_store != kernel_map.end()) {
  132. std::string unsupported_reason;
  133. if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) {
  134. op_desc->SetOpEngineName(op_engine_name);
  135. op_desc->SetOpKernelLibName(kernel_name);
  136. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  137. op_engine_name.c_str(), op_desc->GetName().c_str());
  138. return SUCCESS;
  139. } else {
  140. ErrorManager::GetInstance().ATCReportErrMessage(
  141. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  142. GELOGE(FAILED, "[Call][CheckSupported] failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  143. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  144. return FAILED;
  145. }
  146. } else {
  147. ErrorManager::GetInstance().ATCReportErrMessage(
  148. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  149. GELOGE(FAILED, "[Check][Param] Can not find any supported ops kernel info store by kernel_name %s,"
  150. "op type is %s, op name is %s",
  151. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  152. }
  153. return FAILED;
  154. }
  155. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
  156. bool attr, int32_t &data_index) {
  157. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  158. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  159. auto format = tensor.GetFormat();
  160. auto data_type = tensor.GetDataType();
  161. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  162. return SUCCESS;
  163. }
  164. string op_type;
  165. bool is_const = false;
  166. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  167. if (is_const) {
  168. GELOGD("Get input[%d] is const", index);
  169. op_type = CONSTANTOP;
  170. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  171. op_type = DATA;
  172. }
  173. string op_name = node->GetName() + "_in_" + std::to_string(index);
  174. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  175. if (data_op == nullptr) {
  176. REPORT_CALL_ERROR("E19999", "create OpDesc failed, name:%s", op_name.c_str());
  177. GELOGE(FAILED, "[Create][OpDesc] failed, name:%s", op_name.c_str());
  178. return FAILED;
  179. }
  180. if (is_const) {
  181. ConstGeTensorPtr tensor_value;
  182. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  183. REPORT_CALL_ERROR("E19999", "get attr %s failed, tensor:%s.",
  184. ge::ATTR_NAME_WEIGHTS.c_str(), tensor.GetName().c_str());
  185. GELOGE(FAILED, "[Get][Attr] %s failed, tensor:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), tensor.GetName().c_str());
  186. return FAILED;
  187. }
  188. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  189. REPORT_CALL_ERROR("E19999", "set attr %s failed, op:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), op_name.c_str());
  190. GELOGE(FAILED, "[Set][Attr] %s failed, op:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), op_name.c_str());
  191. return FAILED;
  192. }
  193. }
  194. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  195. (void)AttrUtils::SetBool(data_op, ATTR_NAME_IS_ORIGINAL_INPUT, true);
  196. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS,
  197. REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", data_op->GetName().c_str());
  198. return FAILED, "[Add][InputDesc] fail for node:%s", data_op->GetName().c_str());
  199. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS,
  200. REPORT_CALL_ERROR("E19999", "AddOutputDesc failed for node:%s", data_op->GetName().c_str());
  201. return FAILED, "[Add][OutputDesc] fail for node:%s", data_op->GetName().c_str());
  202. if (attr && !is_const) {
  203. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index),
  204. REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s",
  205. ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  206. return FAILED,
  207. "[Set][Attr:%s] fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  208. ++data_index;
  209. }
  210. ge::NodePtr arg_node = graph->AddNode(data_op);
  211. GE_CHK_BOOL_EXEC(arg_node != nullptr,
  212. REPORT_CALL_ERROR("E19999", "add node:%s to graph:%s failed", data_op->GetName().c_str(),
  213. graph->GetName().c_str());
  214. return FAILED, "[Add][Node] Insert Data node:%s fail", data_op->GetName().c_str());
  215. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  216. "[Add][Edge]fail from node:%s to node:%s", data_op->GetName().c_str(), node->GetName().c_str());
  217. return SUCCESS;
  218. }
  219. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  220. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  221. if (op_desc == nullptr) {
  222. REPORT_CALL_ERROR("E19999", "create OpDesc failed, graph:%s", graph->GetName().c_str());
  223. GELOGE(FAILED, "[Create][OpDesc] failed, graph:%s", graph->GetName().c_str());
  224. return FAILED;
  225. }
  226. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  227. int32_t count = 0;
  228. for (const auto &out_desc : outputs) {
  229. GeTensorDesc tensor = out_desc.GetTensorDesc();
  230. TensorUtils::SetInputTensor(tensor, true);
  231. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS,
  232. REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", op_desc->GetName().c_str());
  233. return FAILED, "[Add][InputDesc]fail for node:%s", op_desc->GetName().c_str());
  234. TensorUtils::SetInputTensor(tensor, false);
  235. TensorUtils::SetOutputTensor(tensor, true);
  236. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS,
  237. REPORT_CALL_ERROR("E19999", "AddOutputDesc failed for node:%s", op_desc->GetName().c_str());
  238. return FAILED, "[Add][OutputDesc]fail for node:%s", op_desc->GetName().c_str());
  239. count++;
  240. }
  241. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  242. ge::NodePtr out_node = graph->AddNode(op_desc);
  243. GE_CHK_BOOL_EXEC(out_node != nullptr,
  244. REPORT_CALL_ERROR("E19999", "add node:%s to graph:%u failed.",
  245. op_desc->GetName().c_str(), graph->GetGraphID());
  246. return FAILED,
  247. "[Add][Node:%s]fail in graph:%u", op_desc->GetName().c_str(), graph->GetGraphID());
  248. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  249. for (int32_t i = 0; i < count; ++i) {
  250. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  251. "[Add][Edge]fail from node:%s to node:%s", node->GetName().c_str(), out_node->GetName().c_str());
  252. }
  253. return SUCCESS;
  254. }
  255. static void GetOpsProtoPath(string &opsproto_path) {
  256. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  257. if (path_env != nullptr) {
  258. string path = path_env;
  259. string file_path = RealPath(path.c_str());
  260. if (file_path.empty()) {
  261. REPORT_CALL_ERROR("E19999", "File path %s is invalid.", path.c_str());
  262. GELOGE(FAILED, "[Call][RealPath] File path %s is invalid.", path.c_str());
  263. return;
  264. }
  265. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  266. GELOGI("Get opsproto so path from env : %s", path.c_str());
  267. return;
  268. }
  269. string path_base = PluginManager::GetPath();
  270. GELOGI("path_base is %s", path_base.c_str());
  271. path_base = path_base.substr(0, path_base.rfind('/'));
  272. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  273. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  274. }
  275. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  276. for (auto input : inputs) {
  277. auto input_desc = input.GetTensorDesc();
  278. GeShape shape_ori = input_desc.GetShape();
  279. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  280. GeShape dynamic_shape(dynamic_shape_dims);
  281. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  282. ge::GeTensor inputTensor;
  283. ge::GeTensorDesc desc(input_desc);
  284. bool is_const = false;
  285. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  286. if (!is_const) {
  287. int64_t storage_format = FORMAT_NCHW;
  288. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  289. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  290. REPORT_CALL_ERROR("E19999", "Set attr ATTR_NAME_STORAGE_SHAPE failed to op:%s.", desc.GetName().c_str());
  291. GELOGE(FAILED, "[Set][Attr] ATTR_NAME_STORAGE_SHAPE fail.");
  292. return FAILED;
  293. }
  294. desc.SetShape(dynamic_shape);
  295. desc.SetShapeRange(dynamic_shape_range);
  296. }
  297. inputTensor.SetTensorDesc(desc);
  298. inputs_dynamic.push_back(inputTensor);
  299. }
  300. return SUCCESS;
  301. }
  302. static Status GetFuzzBuildAttrs(const OpDescPtr &op_desc, const GeRootModelPtr &ge_root_model,
  303. GeAttrValue::LIST_NAMED_ATTRS &fuzz_build_attrs) {
  304. GELOGD("Start get fuzz build attrs of %s.", op_desc->GetName().c_str());
  305. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  306. for (const auto &node : ge_root_model->GetRootGraph()->GetAllNodes()) {
  307. GE_CHECK_NOTNULL(node);
  308. GE_CHECK_NOTNULL(node->GetOpDesc());
  309. GELOGD("Delete fuzz build attr of %s after build.", node->GetName().c_str());
  310. node->GetOpDesc()->DelAttr(ATTR_NAME_FUZZ_BUILD);
  311. }
  312. (void)AttrUtils::GetListNamedAttrs(op_desc, ATTR_NAME_FUZZ_BUILD_RES_ATTRS, fuzz_build_attrs);
  313. if (!fuzz_build_attrs.empty()) {
  314. GELOGD("%s has split, get ATTR_NAME_FUZZ_BUILD_RES_ATTRS directly.", op_desc->GetName().c_str());
  315. return SUCCESS;
  316. } else {
  317. GELOGW("%s build with fuzz build pattern, but not set ATTR_NAME_FUZZ_BUILD_RES_ATTRS.", op_desc->GetName().c_str());
  318. }
  319. return SUCCESS;
  320. }
  321. static bool HasShapeRange(const vector<GeTensor> &inputs) {
  322. for (const auto &input : inputs) {
  323. vector<pair<int64_t, int64_t>> shape_range;
  324. (void)input.GetTensorDesc().GetShapeRange(shape_range);
  325. if (!shape_range.empty()) {
  326. GELOGD("Has set shape range.");
  327. return true;
  328. }
  329. }
  330. return false;
  331. }
  332. class GeGenerator::Impl {
  333. public:
  334. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  335. ~Impl() = default;
  336. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  337. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  338. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  339. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  340. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  341. Status GenerateInfershapeGraph(const Graph &graph);
  342. OmgContext &omg_context_;
  343. GraphManager graph_manager_;
  344. SaveParam save_param_;
  345. bool is_offline_ = true;
  346. bool is_singleop_unregistered_ = false;
  347. std::string build_mode_;
  348. std::string build_step_;
  349. static std::mutex mutex_;
  350. private:
  351. static std::string Trim(const std::string &str);
  352. bool ParseVersion(const std::string &line, std::string &version);
  353. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  354. bool SetAtcVersionInfo(AttrHolder &obj);
  355. bool SetOppVersionInfo(AttrHolder &obj);
  356. bool SetOmSystemInfo(AttrHolder &obj);
  357. };
  358. Status GeGenerator::Initialize(const map<string, string> &options) {
  359. return Initialize(options, domi::GetContext());
  360. }
  361. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  362. impl_ = ge::MakeShared<Impl>(omg_context);
  363. if (impl_ == nullptr) {
  364. REPORT_CALL_ERROR("E19999", "create Impl failed.");
  365. GELOGE(MEMALLOC_FAILED, "[Create][Impl] Make shared failed");
  366. return MEMALLOC_FAILED;
  367. }
  368. ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOpsProtoInit);
  369. string opsproto_path;
  370. GetOpsProtoPath(opsproto_path);
  371. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  372. OpsProtoManager *manager = OpsProtoManager::Instance();
  373. map<string, string> option_tmp;
  374. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  375. (void)manager->Initialize(option_tmp);
  376. Status ret = impl_->graph_manager_.Initialize(options);
  377. if (ret != SUCCESS) {
  378. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "[Call][Initialize] Graph manager initialize failed.");
  379. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  380. }
  381. // get ek file
  382. auto iter = options.find(EK_FILE);
  383. if (iter != options.end()) {
  384. impl_->save_param_.ek_file = iter->second;
  385. }
  386. // get cert file
  387. iter = options.find(CERT_FILE);
  388. if (iter != options.end()) {
  389. impl_->save_param_.cert_file = iter->second;
  390. }
  391. // get hw key file
  392. iter = options.find(HW_KEY_FILE);
  393. if (iter != options.end()) {
  394. impl_->save_param_.hw_key_file = iter->second;
  395. }
  396. // get private file
  397. iter = options.find(PRIVATE_KEY_FILE);
  398. if (iter != options.end()) {
  399. impl_->save_param_.pri_key_file = iter->second;
  400. }
  401. // get build mode
  402. iter = options.find(BUILD_MODE);
  403. if (iter != options.end()) {
  404. impl_->build_mode_ = iter->second;
  405. }
  406. // get build step
  407. iter = options.find(BUILD_STEP);
  408. if (iter != options.end()) {
  409. impl_->build_step_ = iter->second;
  410. }
  411. return SUCCESS;
  412. }
  413. Status GeGenerator::Finalize() {
  414. ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
  415. if (impl_ == nullptr) {
  416. return SUCCESS;
  417. }
  418. Status ret = impl_->graph_manager_.Finalize();
  419. if (ret != SUCCESS) {
  420. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed.");
  421. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  422. }
  423. return SUCCESS;
  424. }
  425. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  426. const vector<GeTensor> &inputs) {
  427. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  428. GELOGI("Start to generate offline model.");
  429. ModelBufferData model;
  430. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  431. }
  432. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  433. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  434. return GenerateModel(graph, "online", inputs, model, false);
  435. }
  436. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  437. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  438. Status ret = impl_->GenerateInfershapeGraph(graph);
  439. if (ret != SUCCESS) {
  440. GELOGE(ret, "[Call][GenerateInfershapeGraph] Dump infershape json failed");
  441. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  442. GELOGE(FAILED, "[Call][Finalize] graph_manager finalize fail.");
  443. }
  444. return ret;
  445. }
  446. GELOGI("Generate infer shape graph success");
  447. return SUCCESS;
  448. }
  449. std::mutex GeGenerator::Impl::mutex_;
  450. // Remove the space and tab before and after the string
  451. std::string GeGenerator::Impl::Trim(const std::string &str) {
  452. if (str.empty()) {
  453. return str;
  454. }
  455. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  456. if (start == std::string::npos) {
  457. return str;
  458. }
  459. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  460. return str.substr(start, end);
  461. }
  462. // Parsing the command line
  463. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  464. std::string flag = "Version=";
  465. std::string temp = Trim(line);
  466. if (temp.empty()) {
  467. GELOGW("line is empty.");
  468. return false;
  469. }
  470. std::string::size_type pos = temp.find(flag);
  471. if (pos == std::string::npos) {
  472. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  473. return false;
  474. }
  475. if (temp.size() == flag.size()) {
  476. GELOGW("version information is empty. %s", line.c_str());
  477. return false;
  478. }
  479. version = temp.substr(pos + flag.size());
  480. return true;
  481. }
  482. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  483. // Normalize the path
  484. string resolved_file_path = RealPath(file_path.c_str());
  485. if (resolved_file_path.empty()) {
  486. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  487. return false;
  488. }
  489. std::ifstream fs(resolved_file_path, std::ifstream::in);
  490. if (!fs.is_open()) {
  491. GELOGW("Open %s failed.", file_path.c_str());
  492. return false;
  493. }
  494. std::string line;
  495. if (getline(fs, line)) {
  496. if (!ParseVersion(line, version)) {
  497. GELOGW("Parse version failed. content is [%s].", line.c_str());
  498. fs.close();
  499. return false;
  500. }
  501. } else {
  502. GELOGW("No version information found in the file path:%s", file_path.c_str());
  503. fs.close();
  504. return false;
  505. }
  506. fs.close(); // close the file
  507. return true;
  508. }
  509. // Set package version information in the model
  510. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  511. std::string path_base = ge::GELib::GetPath();
  512. path_base = path_base.substr(0, path_base.rfind('/'));
  513. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  514. std::string version_path = path_base + "version.info";
  515. std::string version;
  516. if (!GetVersionFromPath(version_path, version)) {
  517. GELOGW("Get atc version information failed!");
  518. return false;
  519. }
  520. // set version info
  521. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  522. GELOGW("Ge model set atc version failed!");
  523. return false;
  524. }
  525. return true;
  526. }
  527. // Set package version information in the model
  528. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  529. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  530. if (path_env == nullptr) {
  531. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  532. return false;
  533. }
  534. std::string version_path = path_env;
  535. version_path += "/version.info";
  536. std::string version;
  537. if (!GetVersionFromPath(version_path, version)) {
  538. GELOGW("Get opp version information failed!");
  539. return false;
  540. }
  541. // set version info
  542. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  543. GELOGW("Ge model set opp version failed!");
  544. return false;
  545. }
  546. return true;
  547. }
  548. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  549. std::string soc_version;
  550. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  551. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  552. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  553. GELOGW("SetStr of soc_version failed.");
  554. return false;
  555. }
  556. std::string framework_type;
  557. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  558. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  559. auto iter = ge::kFwkTypeToStr.find(framework_type);
  560. if (iter == ge::kFwkTypeToStr.end()) {
  561. GELOGW("Can not find framework_type in the map.");
  562. return false;
  563. }
  564. if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
  565. GELOGW("SetStr of framework_type failed.");
  566. return false;
  567. }
  568. return true;
  569. }
  570. Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
  571. bool is_unknown_shape = false;
  572. Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  573. if (ret != SUCCESS) {
  574. GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
  575. ge_root_model->GetModelId());
  576. REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%u",
  577. ge_root_model->GetModelId());
  578. return FAILED;
  579. }
  580. GeModelPtr model_root = nullptr;
  581. if (is_unknown_shape) {
  582. model_root = MakeShared<GeModel>();
  583. GE_CHECK_NOTNULL(model_root);
  584. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  585. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  586. }
  587. ModelHelper model_helper;
  588. string model_name;
  589. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  590. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  591. model_name);
  592. if (name_ret != SUCCESS) {
  593. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  594. GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
  595. ge_root_model->GetRootGraph()->GetName().c_str());
  596. return PARAM_INVALID;
  597. }
  598. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  599. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  600. GE_CHECK_NOTNULL(ge_model);
  601. ge_model->SetName(model_name);
  602. return SUCCESS;
  603. }
  604. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  605. ModelBufferData &model, bool is_offline) {
  606. rtContext_t ctx = nullptr;
  607. auto rt = rtCtxGetCurrent(&ctx);
  608. if (rt != RT_ERROR_NONE) {
  609. GELOGD("Current ctx is null.");
  610. ctx = nullptr;
  611. }
  612. std::function<void()> callback = [&]() {
  613. if (ctx != nullptr) {
  614. (void)rtCtxSetCurrent(ctx);
  615. }
  616. };
  617. GE_MAKE_GUARD(restore, callback);
  618. GeRootModelPtr ge_root_model = nullptr;
  619. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  620. impl_->is_offline_ = is_offline;
  621. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  622. if (ret != SUCCESS) {
  623. GELOGE(ret, "[Build][Model] failed, ret:%d.", ret);
  624. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  625. GELOGE(FAILED, "[Call][Finalize] graph_manager finalize fail.");
  626. }
  627. return ret;
  628. }
  629. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  630. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  631. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  632. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  633. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  634. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  635. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  636. impl_->build_mode_.c_str(),
  637. impl_->build_step_.c_str());
  638. return SUCCESS;
  639. }
  640. GE_CHECK_NOTNULL(ge_root_model);
  641. ret = SetModelNameForDump(ge_root_model);
  642. if (ret != SUCCESS) {
  643. return ret;
  644. }
  645. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  646. if (ret != SUCCESS) {
  647. GELOGE(ret, "[Save][RootModel] failed, ret:%d, file:%s", ret, file_name_prefix.c_str());
  648. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  649. GELOGE(FAILED, "graph_manager finalize fail.");
  650. }
  651. return ret;
  652. }
  653. return SUCCESS;
  654. }
  655. namespace {
  656. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  657. bool is_need = true;
  658. // format and dtype is all reserved, stand for Optional input. When singleop scene
  659. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  660. is_need = false;
  661. }
  662. return is_need;
  663. }
  664. Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) {
  665. bool support_dynamic = true;
  666. bool is_dynamic = false;
  667. for (const auto &node : graph->GetDirectNode()) {
  668. GE_CHECK_NOTNULL(node);
  669. auto op_desc = node->GetOpDesc();
  670. GE_CHECK_NOTNULL(op_desc);
  671. if (op_desc->GetOpEngineName() != kAIcoreEngine) {
  672. continue;
  673. }
  674. if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) {
  675. is_dynamic = true;
  676. (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic);
  677. if (!support_dynamic) {
  678. GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str());
  679. break;
  680. }
  681. }
  682. }
  683. if (is_dynamic) {
  684. (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic);
  685. }
  686. return SUCCESS;
  687. }
  688. }
  689. bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) {
  690. for (const auto &node : graph->GetDirectNode()) {
  691. if (node == nullptr) {
  692. continue;
  693. }
  694. auto op_desc = node->GetOpDesc();
  695. if (op_desc == nullptr) {
  696. continue;
  697. }
  698. if (op_desc->GetOpEngineName() == kAIcoreEngine) {
  699. return false;
  700. }
  701. }
  702. return true;
  703. }
  704. void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  705. for (auto &input : inputs) {
  706. GeTensorDesc input_desc = input.GetTensorDesc();
  707. bool is_const = false;
  708. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  709. bool is_optional = IsOptional(input_desc);
  710. if (!is_optional && !is_const) {
  711. outputs.emplace_back(input);
  712. }
  713. }
  714. }
  715. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  716. const vector<GeTensor> &outputs) {
  717. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  718. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  719. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  720. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  721. "tensor size is " + FmtToStr(inputs.size())});
  722. GELOGE(PARAM_INVALID, "[Check][Param] Tensor size: %zu, op:%s(%s) Inputs size: %zu, not equal",
  723. inputs.size(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetAllInputsSize());
  724. return PARAM_INVALID;
  725. }
  726. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  727. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  728. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  729. "tensor size is " + FmtToStr(outputs.size())});
  730. GELOGE(PARAM_INVALID, "[Check][Param] Tensor size: %zu, op:%s(%s) Outputs size: %zu, not equal",
  731. outputs.size(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetOutputsSize());
  732. return PARAM_INVALID;
  733. }
  734. return SUCCESS;
  735. }
  736. Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) {
  737. GE_CHECK_NOTNULL(op_desc);
  738. if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) {
  739. auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType());
  740. if (node_op.IsEmpty()) {
  741. GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str());
  742. } else {
  743. GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str());
  744. auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  745. if (temp_op_desc == nullptr) {
  746. REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s",
  747. op_desc->GetType().c_str());
  748. GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str());
  749. return FAILED;
  750. }
  751. if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
  752. GELOGW("InferFormatForSingleOp UpdateInputName failed");
  753. }
  754. if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
  755. GELOGW("InferFormatForSingleOp UpdateOutputName failed");
  756. }
  757. }
  758. node_op.BreakConnect();
  759. }
  760. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  761. GE_CHECK_NOTNULL(comp_graph);
  762. auto node = comp_graph->FindNode(op_desc->GetName());
  763. GE_CHECK_NOTNULL(node);
  764. auto op = OpDescUtils::CreateOperatorFromNode(node);
  765. auto ret = op_desc->CallInferFormatFunc(op);
  766. if (ret != GRAPH_SUCCESS) {
  767. REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail",
  768. op_desc->GetName().c_str());
  769. GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str());
  770. return FAILED;
  771. }
  772. return SUCCESS;
  773. }
  774. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  775. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  776. bool is_offline, int32_t compile_flag) {
  777. GELOGD("Inputs size is %zu, outputs size is %zu.", inputs.size(), outputs.size());
  778. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  779. impl_->is_offline_ = is_offline;
  780. (void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
  781. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  782. GELOGE(PARAM_INVALID, "[Check][Param] input param is invalid when build single op:%s!",
  783. op_desc->GetName().c_str());
  784. return PARAM_INVALID;
  785. }
  786. OmgContext &omg_context = impl_->omg_context_;
  787. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  788. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  789. impl_->is_singleop_unregistered_ = true;
  790. }
  791. // 0. Save original attributes.
  792. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  793. GE_CHECK_NOTNULL(op_desc_tmp);
  794. bool fuzz_compile_flag = false;
  795. if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) {
  796. fuzz_compile_flag = true;
  797. }
  798. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag);
  799. impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag;
  800. // 1. Create ComputeGraph.
  801. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  802. Graph graph;
  803. GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph),
  804. "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str());
  805. GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph));
  806. // 2. check engine type when compile online
  807. if (model_file_name == kFileNameSuffix) {
  808. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  809. GE_CHECK_NOTNULL(comp_graph);
  810. auto node = comp_graph->FindNode(op_desc->GetName());
  811. Status ret = CheckEngineTypeSupport(node, engine_type);
  812. if (ret != SUCCESS) {
  813. GELOGE(ret, "[Check][EngineType]not support node:%s with engine of %d.", node->GetName().c_str(), engine_type);
  814. return ret;
  815. }
  816. }
  817. GELOGI("ATC parser success in single op build.");
  818. GeRootModelPtr ge_root_model = nullptr;
  819. vector<GeTensor> data_inputs;
  820. RemoveConst(inputs, data_inputs);
  821. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
  822. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  823. GE_CHECK_NOTNULL(ge_root_model);
  824. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  825. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  826. if (name_to_ge_model.empty()) {
  827. REPORT_CALL_ERROR("E19999", "GetSubgraphInstanceNameToModel failed.");
  828. GELOGE(PARAM_INVALID, "[Get][Name] GetSubgraphInstanceNameToModel is empty.");
  829. return PARAM_INVALID;
  830. }
  831. const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph();
  832. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  833. GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph));
  834. GELOGI("After build model, The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  835. bool all_shape = false;
  836. (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
  837. GELOGD("Node: %s, all_shape is %d, compile_flag is %d.", op_desc->GetName().c_str(), all_shape, compile_flag);
  838. (void)AttrUtils::SetInt(ge_model, ATTR_NAME_BUILD_MODE, fuzz_compile_flag);
  839. if (all_shape) {
  840. (void)AttrUtils::SetBool(ge_model, kAicpuAllshape, all_shape);
  841. }
  842. if (all_shape && CheckNoAicore(root_graph)) {
  843. GELOGD("Get aicpu all_shape kernel!");
  844. vector<GeTensor> inputs_dynamic;
  845. vector<GeTensor> outputs_dynamic;
  846. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  847. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  848. GE_CHK_STATUS_RET_NOLOG(
  849. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  850. } else if (fuzz_compile_flag) {
  851. GeAttrValue::LIST_NAMED_ATTRS fuzz_build_attrs;
  852. if (GetFuzzBuildAttrs(op_desc, ge_root_model, fuzz_build_attrs) != SUCCESS) {
  853. GELOGE(FAILED, "[Get][FuzzRet]Failed to get fuzz build result of %s.", op_desc->GetName().c_str());
  854. return FAILED;
  855. }
  856. if (!fuzz_build_attrs.empty()) {
  857. GE_CHK_BOOL_EXEC(AttrUtils::SetListNamedAttrs(ge_model, ATTR_NAME_FUZZ_BUILD_RES_ATTRS, fuzz_build_attrs),
  858. REPORT_CALL_ERROR("E19999", "Set model:%s(id:%u) attr:%s failed.",
  859. ge_model->GetName().c_str(), ge_model->GetModelId(),
  860. ATTR_NAME_FUZZ_BUILD_RES_ATTRS.c_str());
  861. return FAILED, "Set model:%s(id:%u) attr:%s failed.",
  862. ge_model->GetName().c_str(), ge_model->GetModelId(), ATTR_NAME_FUZZ_BUILD_RES_ATTRS.c_str());
  863. }
  864. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  865. } else {
  866. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  867. }
  868. GELOGI("Start save GeModel to Model buffer");
  869. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  870. return SUCCESS;
  871. }
  872. /**
  873. * @ingroup ge
  874. * @brief Compiling a single operator into an offline model
  875. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  876. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  877. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  878. * @param [in] const string &model_file_name: Offline model filename.
  879. * @param [in] compile_flag: op build flag from atc
  880. * @return SUCCESS handle successfully / others handle failed
  881. */
  882. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  883. const vector<GeTensor> &outputs, const string &model_file_name,
  884. int32_t compile_flag) {
  885. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  886. GELOGI("Start to build single op offline model, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  887. ModelBufferData model_buff;
  888. OpEngineType engine_type = ENGINE_SYS;
  889. Status status = BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true, compile_flag);
  890. GELOGI("Finish build single offline model, status: %u", status);
  891. return status;
  892. }
  893. /**
  894. * @ingroup ge
  895. * @brief Compiling a single operator into online buffer
  896. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  897. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  898. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  899. * @param [in] engine_type: specific engine.
  900. * @param [in] compile_flag: op build flag, compile flag by acl
  901. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  902. * @return SUCCESS handle successfully / others handle failed
  903. */
  904. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  905. const vector<GeTensor> &outputs, OpEngineType engine_type,
  906. ModelBufferData &model_buff) {
  907. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  908. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  909. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  910. GELOGI("Finish build single online model, status: %u", status);
  911. return status;
  912. }
  913. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  914. const vector<GeTensor> &outputs, OpEngineType engine_type, int32_t compile_flag,
  915. ModelBufferData &model_buff) {
  916. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  917. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  918. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false,
  919. compile_flag);
  920. GELOGI("Finish build single online model, status: %u", status);
  921. return status;
  922. }
  923. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  924. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  925. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  926. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  927. // 1. Add Node to ComputeGraph.
  928. NodePtr op_node = compute_graph->AddNode(op_desc);
  929. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  930. // 2. Create InputData node.
  931. int32_t arg_index = 0;
  932. int32_t data_index = 0;
  933. if (inputs.empty()) {
  934. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  935. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  936. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  937. continue;
  938. }
  939. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
  940. arg_index++;
  941. }
  942. } else {
  943. for (const auto &in_desc : inputs) {
  944. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
  945. arg_index++;
  946. }
  947. }
  948. // 3. Create Output node.
  949. if (!outputs.empty()) {
  950. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  951. }
  952. // dump ComputeGraph node.
  953. compute_graph->Dump();
  954. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  955. return SUCCESS;
  956. }
  957. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  958. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  959. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  960. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  961. (void)graph_manager_.Finalize();
  962. return FAILED);
  963. return SUCCESS;
  964. }
  965. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  966. // set atc version
  967. if (!SetAtcVersionInfo(*(model.get()))) {
  968. GELOGW("SetPackageVersionInfo of atc failed!");
  969. }
  970. // set opp version
  971. if (!SetOppVersionInfo(*(model.get()))) {
  972. GELOGW("SetPackageVersionInfo of ops failed!");
  973. }
  974. ModelHelper model_helper;
  975. model_helper.SetSaveMode(is_offline_);
  976. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  977. if (ret != SUCCESS) {
  978. GELOGE(ret, "[Call][SaveToOmModel] Save to om model failed");
  979. return ret;
  980. }
  981. return SUCCESS;
  982. }
  983. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  984. ModelBufferData &model_buff) {
  985. bool is_unknown_shape = false;
  986. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  987. if (ret != SUCCESS) {
  988. REPORT_CALL_ERROR("E19999", "root model(id:%u) CheckIsUnknownShape failed, ret:%d",
  989. ge_root_model->GetModelId(), ret);
  990. GELOGE(FAILED, "[Check][RootModel] is unkonwn shape failed, ret:%d", ret);
  991. return FAILED;
  992. }
  993. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  994. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(),
  995. REPORT_CALL_ERROR("E19999", "root model(id:%u) has no sub model.", ge_root_model->GetModelId());
  996. return FAILED, "[Get][SubModel] ge root model has no sub model")
  997. GeModelPtr model_root = nullptr;
  998. if (is_unknown_shape) {
  999. auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  1000. model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  1001. } else {
  1002. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  1003. }
  1004. GE_CHECK_NOTNULL(model_root);
  1005. // set atc version
  1006. if (!SetAtcVersionInfo(*(model_root.get()))) {
  1007. GELOGW("SetPackageVersionInfo of atc failed!");
  1008. }
  1009. // set opp version
  1010. if (!SetOppVersionInfo(*(model_root.get()))) {
  1011. GELOGW("SetPackageVersionInfo of ops failed!");
  1012. }
  1013. if (!SetOmSystemInfo(*(model_root.get()))) {
  1014. GELOGW("SetOmsystemInfo failed!");
  1015. }
  1016. ModelHelper model_helper;
  1017. model_helper.SetSaveMode(is_offline_);
  1018. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  1019. if (ret != SUCCESS) {
  1020. REPORT_CALL_ERROR("E19999", "SaveToOmRootModel failed, ret:%d, model id:%u", ret, ge_root_model->GetModelId());
  1021. GELOGE(ret, "[Call][SaveToOmRootModel] failed, ret:%d, model id:%u", ret, ge_root_model->GetModelId());
  1022. return ret;
  1023. }
  1024. return SUCCESS;
  1025. }
  1026. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  1027. GeRootModelPtr &ge_root_model) {
  1028. static std::atomic<GraphId> atomic_graph_id(0);
  1029. auto graph_id = atomic_graph_id.fetch_add(1);
  1030. const std::map<std::string, std::string> options;
  1031. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  1032. if (ret != SUCCESS) {
  1033. REPORT_CALL_ERROR("E19999", "add graph(id:%u) failed, ret:%d", graph_id, ret);
  1034. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "[Add][Graph] fail, graph id: %u", graph_id);
  1035. (void)graph_manager_.Finalize();
  1036. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  1037. }
  1038. graph_manager_.SetOptionsRunGraphFlag(false);
  1039. // prevent from repeatation of session_id in mix using singleop and graph mode.
  1040. static std::atomic<uint64_t> atomic_session_id(kInitialSessionId);
  1041. auto session_id = atomic_session_id.fetch_add(1);
  1042. // This is a temporary add for graph with variable
  1043. auto version = static_cast<int32_t>(SessionVersion::ClOUD_VERSION);
  1044. ret = VarManager::Instance(session_id)->Init(version, session_id, kDefaultDeviceId, kDefaultJobId);
  1045. GELOGI("Start init var instance, session_id %lu", session_id);
  1046. if (ret != SUCCESS) {
  1047. GELOGW("Failed init var instance, session_id %lu", session_id);
  1048. }
  1049. if (is_singleop_unregistered_) {
  1050. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  1051. } else {
  1052. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  1053. }
  1054. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1055. if (ret != SUCCESS) {
  1056. REPORT_CALL_ERROR("E19999", "build graph failed, graph id:%u, ret:%d", graph_id, ret);
  1057. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "[Build][Graph] fail, graph id: %u", graph_id);
  1058. }
  1059. RtContextUtil::GetInstance().DestroyRtContexts(session_id);
  1060. Analyzer::GetInstance()->DestroySessionJsonObject(session_id);
  1061. VarManagerPool::Instance().RemoveVarManager(session_id);
  1062. return ret;
  1063. }
  1064. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  1065. static std::atomic<GraphId> atomic_graph_id(0);
  1066. auto graph_id = atomic_graph_id.fetch_add(1);
  1067. const std::map<std::string, std::string> options;
  1068. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  1069. if (ret != SUCCESS) {
  1070. REPORT_CALL_ERROR("E19999", "add graph failed, graph id:%u, ret:%d", graph_id, ret);
  1071. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "[Add][Graph] failed, graph id: %u", graph_id);
  1072. (void)graph_manager_.Finalize();
  1073. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  1074. }
  1075. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  1076. if (ret != SUCCESS) {
  1077. REPORT_CALL_ERROR("E19999", "GenerateInfershapeGraph failed, graph id:%u, ret:%d", graph_id, ret);
  1078. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED,
  1079. "[Generate][Graph] failed, graph id:%u, ret:%d", graph_id, ret);
  1080. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  1081. }
  1082. return SUCCESS;
  1083. }
  1084. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：