You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

hybrid_model.cc 18 kB

5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "hybrid/model/hybrid_model.h"
  17. #include <vector>
  18. #include "graph/debug/ge_attr_define.h"
  19. #include "graph/load/model_manager/model_utils.h"
  20. #include "graph/utils/graph_utils.h"
  21. #include "graph/utils/node_utils.h"
  22. #include "graph/utils/tensor_utils.h"
  23. #include "graph/utils/type_utils.h"
  24. #include "hybrid/common/npu_memory_allocator.h"
  25. #include "hybrid/model/hybrid_model_builder.h"
  26. #include "hybrid/node_executor/node_executor.h"
  27. #include "framework/common/op/ge_op_utils.h"
  28. namespace ge {
  29. namespace hybrid {
  30. namespace {
  31. const int64_t kMemSizeUnknownShape = -1; // Unknown shape mem size
  32. }
// Takes ownership (shared) of the root model produced by graph compilation.
HybridModel::HybridModel(GeRootModelPtr ge_model) : ge_root_model_(std::move(ge_model)) {
}
// Members (tensors, graph items, node items) are released by their own
// destructors; only a trace log is emitted here.
HybridModel::~HybridModel() {
  GELOGD("[%s] HybridModel destroyed.", model_name_.c_str());
}
  38. Status HybridModel::Init(bool is_single_op) {
  39. GELOGD("Start to init hybrid model.");
  40. is_single_op_ = is_single_op;
  41. if (is_single_op) {
  42. GE_CHK_STATUS_RET(HybridModelBuilder(*this).BuildForSingleOp(), "[Build][HybridModel] for SingleOp failed.");
  43. } else {
  44. GE_CHK_STATUS_RET(HybridModelBuilder(*this).Build(), "[Build][HybridModel] failed.");
  45. }
  46. SaveSpecifyAttrValues();
  47. GELOGD("HybridModel initialized successfully.");
  48. return SUCCESS;
  49. }
  50. TensorValue *HybridModel::GetVariable(const string &name) const {
  51. auto it = variable_tensors_.find(name);
  52. if (it == variable_tensors_.end()) {
  53. GELOGD("Failed to get variable tensor. var name = [%s]", name.c_str());
  54. return nullptr;
  55. }
  56. GELOGD("Got variable tensor. var name = [%s], tensor = %s", name.c_str(), it->second->DebugString().c_str());
  57. return it->second.get();
  58. }
  59. NodePtr HybridModel::GetVariableNode(const string &name) const {
  60. auto it = device_variable_nodes_.find(name);
  61. if (it != device_variable_nodes_.end()) {
  62. return it->second;
  63. }
  64. auto host_find = host_variable_nodes_.find(name);
  65. if (host_find != host_variable_nodes_.end()) {
  66. return host_find->second;
  67. }
  68. GELOGD("Failed to get variable node by name = [%s]", name.c_str());
  69. return nullptr;
  70. }
  71. const std::vector<domi::TaskDef> *HybridModel::GetTaskDefs(const NodePtr &node) const {
  72. auto it = task_defs_.find(node);
  73. if (it == task_defs_.end()) {
  74. return nullptr;
  75. }
  76. return &it->second;
  77. }
  78. NodeItem *HybridModel::MutableNodeItem(const NodePtr &node) {
  79. auto it = node_items_.find(node);
  80. if (it == node_items_.end()) {
  81. return nullptr;
  82. }
  83. return it->second.get();
  84. }
  85. const NodeItem *HybridModel::GetNodeItem(const NodePtr &node) const {
  86. auto it = node_items_.find(node);
  87. if (it == node_items_.end()) {
  88. return nullptr;
  89. }
  90. return it->second.get();
  91. }
  92. GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const {
  93. auto it = known_shape_sub_models_.find(node);
  94. if (it == known_shape_sub_models_.end()) {
  95. GELOGE(INTERNAL_ERROR, "[Check][Param:node][%s] Failed to get GeModel for subgraph node,"
  96. "because node not in known_shape_sub_models_.", node->GetName().c_str());
  97. REPORT_INNER_ERROR("E19999", "%s Failed to get GeModel for subgraph node,"
  98. "because node not in known_shape_sub_models_.", node->GetName().c_str());
  99. return nullptr;
  100. }
  101. return it->second;
  102. }
// Non-owning accessor for the root graph item built during Init.
const GraphItem *HybridModel::GetRootGraphItem() const {
  return root_graph_item_.get();
}
// Accessor for the root compute graph this model was built from.
const ComputeGraphPtr &HybridModel::GetRootGraph() const {
  return root_graph_;
}
  109. const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const {
  110. GELOGD("To find subgraph item by name = %s", graph_name.c_str());
  111. auto it = subgraph_items_.find(graph_name);
  112. if (it == subgraph_items_.end()) {
  113. GELOGD("Subgraph item not found by node = %s", graph_name.c_str());
  114. return nullptr;
  115. }
  116. return it->second.get();
  117. }
  118. const GraphItem *HybridModel::GetSubgraphItem(const ComputeGraphPtr &subgraph) const {
  119. if (subgraph == nullptr) {
  120. REPORT_INNER_ERROR("E19999", "Input param subgraph is nullptr, Graph:%s",
  121. root_graph_item_->GetName().c_str());
  122. GELOGE(PARAM_INVALID, "[Check][Param]subgraph is nullptr. graph:%s",
  123. root_graph_item_->GetName().c_str());
  124. return nullptr;
  125. }
  126. auto subgraph_name = subgraph->GetName();
  127. return GetSubgraphItem(subgraph_name);
  128. }
// Accessor for the model's name (set during build).
const string &HybridModel::GetModelName() const {
  return model_name_;
}
  132. Status HybridModel::GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) {
  133. // dynamic shape do not need dynamic batch
  134. batch_info = {};
  135. dynamic_type = -1;
  136. return SUCCESS;
  137. }
  138. void HybridModel::GetUserDesignateShapeOrder(std::vector<std::string> &user_input_shape_order) {
  139. // dynamic shape do not need dynamic batch
  140. user_input_shape_order = {};
  141. }
  142. void HybridModel::GetModelAttr(std::vector<std::string> &dynamic_output_shape_info) {
  143. dynamic_output_shape_info = {};
  144. }
// Collects input and output tensor descriptions (and their formats) for the
// whole model. Validates that the root graph has at least one input node and
// that the first input node's op has exactly one input, then delegates to
// GetInputDescInfo / GetOutputDescInfo.
// @param input_desc/output_desc   filled with per-tensor descriptions.
// @param input_formats/output_formats  filled with the corresponding formats.
// @return SUCCESS, or FAILED when validation or either delegate fails.
Status HybridModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc,
                                           vector<InputOutputDescInfo> &output_desc,
                                           std::vector<uint32_t> &input_formats,
                                           std::vector<uint32_t> &output_formats) {
  auto node_item_list = root_graph_item_->GetInputNodes();
  if (node_item_list.empty()) {
    REPORT_INNER_ERROR("E19999", "node item list is empty!, graph:%s",
                       root_graph_item_->GetName().c_str());
    GELOGE(FAILED, "[Get][InputNodes]node item list is empty!, graph:%s",
           root_graph_item_->GetName().c_str());
    return FAILED;
  }
  GE_CHECK_NOTNULL(node_item_list[0]->node);
  GE_CHECK_NOTNULL(node_item_list[0]->node->GetOpDesc());
  // Only the first input node is validated here; GetInputDescInfo below reads
  // input 0 of every input node.
  if (node_item_list[0]->node->GetOpDesc()->GetInputsSize() != 1) {
    REPORT_INNER_ERROR("E19999", "Input size of op is not 1, op:%s, type:%s",
                       node_item_list[0]->node->GetName().c_str(),
                       node_item_list[0]->node->GetType().c_str());
    GELOGE(FAILED, "[Check][Size]input size of op is not 1! op:%s, type:%s",
           node_item_list[0]->node->GetName().c_str(),
           node_item_list[0]->node->GetType().c_str());
    return FAILED;
  }
  GE_CHK_STATUS_RET(GetInputDescInfo(input_desc, input_formats), "[Get][InputDescInfo] failed.");
  GE_CHK_STATUS_RET(GetOutputDescInfo(output_desc, output_formats), "[Get][OutputDescInfo] failed.");
  return SUCCESS;
}
  172. void HybridModel::SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims,
  173. std::vector<std::pair<int64_t, int64_t>> &shape_ranges,
  174. InputOutputDescInfo &input) {
  175. for (auto model_input_dim : model_input_dims) {
  176. input.shape_info.dims.push_back(model_input_dim);
  177. }
  178. input.shape_info.shape_ranges = shape_ranges;
  179. return;
  180. }
// Fills `input` with the dims (and shape ranges) for a data op, choosing the
// dim source in priority order:
//   1. ATTR_NAME_INPUT_DIMS        — static-AIPP-processed dims, only honored
//                                    while is_new_model_desc_ is set;
//   2. "_dynamic_aipp_input_dims"  — dims of a linked dynamic AIPP input;
//   3. the op's own input-0 tensor desc (dims + shape range).
// Only case 3 produces non-empty shape ranges.
void HybridModel::CreateInputDimsInfo(const OpDescPtr &op_desc, InputOutputDescInfo &input) {
  std::vector<std::pair<int64_t,int64_t>> shape_ranges;
  if (is_new_model_desc_ && op_desc->HasAttr(ATTR_NAME_INPUT_DIMS)) {
    // When static aipp is set, need to get the model input dims which processed by aipp
    vector<int64_t> model_input_dims;
    (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_DIMS, model_input_dims);
    SetInputDimsAndShapeRangesInfo(model_input_dims, shape_ranges, input);
    return;
  }
  // judge if this data is linked dynamic aipp first, multiply batch has been considered
  if (op_desc->HasAttr("_dynamic_aipp_input_dims")) {
    vector<int64_t> dynamic_aipp_input_dims;
    (void)AttrUtils::GetListInt(op_desc, "_dynamic_aipp_input_dims", dynamic_aipp_input_dims);
    SetInputDimsAndShapeRangesInfo(dynamic_aipp_input_dims, shape_ranges, input);
    return;
  } else {
    // NOTE(review): GetInputDescPtr(0) is dereferenced without a null check here;
    // the caller (GetInputDescInfo) validates it beforehand — confirm no other callers.
    vector<int64_t> input_dims = op_desc->GetInputDescPtr(0)->GetShape().GetDims();
    op_desc->GetInputDescPtr(0)->GetShapeRange(shape_ranges);
    SetInputDimsAndShapeRangesInfo(input_dims, shape_ranges, input);
    return;
  }
}
// Builds one InputOutputDescInfo (name, dtype, dims, size) per input node of
// the root graph, reading input 0 of each node's op desc.
// Side effect: clears is_new_model_desc_ after the first call, so static-AIPP
// dims (see CreateInputDimsInfo) are only reported once per refresh.
// @return SUCCESS, or FAILED when tensor-size calculation fails.
Status HybridModel::GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, std::vector<uint32_t> &formats) {
  auto node_item_list = root_graph_item_->GetInputNodes();
  for (auto &node_item : node_item_list) {
    InputOutputDescInfo input;
    GE_CHECK_NOTNULL(node_item->node);
    auto op_desc = node_item->node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(0));
    Format format = op_desc->GetInputDescPtr(0)->GetFormat();
    DataType data_type = op_desc->GetInputDescPtr(0)->GetDataType();
    input.data_type = static_cast<uint32_t>(data_type);
    input.name = op_desc->GetName();
    GeShape shape = op_desc->GetInputDescPtr(0)->GetShape();
    int64_t tensor_size = 0;
    if (TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "[Calculate][TensorMemSize] failed input0 desc in node:%s."
             "shape:%s, format:%s, datatype:%s.", op_desc->GetName().c_str(),
             shape.ToString().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str());
      REPORT_CALL_ERROR("E19999", "CalcTensorMemSize failed for input0 desc in node:%s,"
                        "shape:%s, format:%s, datatype:%s", op_desc->GetName().c_str(),
                        shape.ToString().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
                        TypeUtils::DataTypeToSerialString(data_type).c_str());
      return FAILED;
    }
    // Unknown-shape tensors (size == -1) are reported as size 0.
    if (tensor_size == kMemSizeUnknownShape) {
      tensor_size = 0;
    }
    input.size = static_cast<uint64_t>(tensor_size);
    CreateInputDimsInfo(op_desc, input);
    formats.push_back(format);
    input_desc.push_back(input);
  }
  is_new_model_desc_ = false;
  return SUCCESS;
}
// Translates one output tensor desc into an InputOutputDescInfo plus its
// reported format. FORMAT_FRACTAL_Z outputs are presented to callers as HWCN:
// dims (and shape ranges, when all 4 are present) are permuted from KCHW
// storage order to HWCK/HWCN reporting order.
// @param output_desc       source tensor desc; nullptr is reported and ignored.
// @param output_desc_info  filled with dims, ranges, size and data type.
// @param format_result     the format callers should interpret the dims in.
void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc,
                               InputOutputDescInfo &output_desc_info, uint32_t &format_result) {
  GE_IF_BOOL_EXEC(output_desc == nullptr,
                  REPORT_INNER_ERROR("E19999", "param output_desc is nullptr, check invalid.");
                  GELOGE(FAILED, "[Check][Param:output_desc]output desc ptr is nullptr");
                  return );
  Format format = output_desc->GetFormat();
  GeShape shape = output_desc->GetShape();
  std::vector<std::pair<int64_t,int64_t>> shape_ranges;
  output_desc->GetShapeRange(shape_ranges);
  DataType data_type = output_desc->GetDataType();
  format_result = format;
  if (format == FORMAT_FRACTAL_Z) { // FraczToHWCK
    int64_t k = shape.GetDim(0); // 0: first dim
    int64_t c = shape.GetDim(1); // 1: second dim
    int64_t h = shape.GetDim(2); // 2: third dim
    int64_t w = shape.GetDim(3); // 3: forth dim
    // Report dims in HWCK order.
    output_desc_info.shape_info.dims.push_back(h);
    output_desc_info.shape_info.dims.push_back(w);
    output_desc_info.shape_info.dims.push_back(c);
    output_desc_info.shape_info.dims.push_back(k);
    // Ranges are only permuted when a full 4-dim range is available.
    if (shape_ranges.size() == 4) { // 4 dims
      output_desc_info.shape_info.shape_ranges.push_back(shape_ranges[2]); // h:2
      output_desc_info.shape_info.shape_ranges.push_back(shape_ranges[3]); // w:3
      output_desc_info.shape_info.shape_ranges.push_back(shape_ranges[1]); // c:1
      output_desc_info.shape_info.shape_ranges.push_back(shape_ranges[0]); // k:0
    }
    format_result = FORMAT_HWCN;
  } else {
    // All other formats: pass dims and ranges through untouched.
    for (size_t j = 0; j < shape.GetDimNum(); j++) {
      output_desc_info.shape_info.dims.push_back(shape.GetDim(j));
    }
    output_desc_info.shape_info.shape_ranges = shape_ranges;
  }
  int64_t tensor_size = 0;
  // Size failure is tolerated here (unlike inputs): unknown shape maps to 0.
  (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size);
  if (tensor_size == kMemSizeUnknownShape) {
    tensor_size = 0;
  }
  output_desc_info.size = static_cast<uint64_t>(tensor_size);
  output_desc_info.data_type = output_desc->GetDataType();
}
// Builds one InputOutputDescInfo per model output (each input of the root
// graph's NetOutput op). Output names come from ATTR_MODEL_OUT_NODES_NAME
// when it lists exactly one name per output, otherwise a synthetic
// "output_<i>_<src>_<idx>" name is generated from the op's src name/index.
// @return SUCCESS, or FAILED on desc-list retrieval or size mismatch.
Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats) {
  std::vector<ConstGeTensorDescPtr> output_desc_list;
  // output_desc_list contains vaild input desc
  GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list),
                    "[Invoke][GetOutputDescList]get output desc info failed, Graph:%s",
                    root_graph_item_->GetName().c_str());
  vector<std::string> out_node_names;
  (void)ge::AttrUtils::GetListStr(ge_root_model_->GetRootGraph(), ATTR_MODEL_OUT_NODES_NAME, out_node_names);
  GE_CHECK_NOTNULL(root_graph_item_->GetOutputNode());
  auto op_desc = root_graph_item_->GetOutputNode()->op_desc;
  GE_CHECK_NOTNULL(op_desc);
  auto out_size = static_cast<uint32_t>(op_desc->GetInputsSize());
  GE_IF_BOOL_EXEC(out_size != output_desc_list.size(),
                  REPORT_INNER_ERROR("E19999", "output size[%u] not match output_desc_list size[%zu]",
                                     out_size, output_desc_list.size());
                  GELOGE(FAILED, "[Check][Size]output size[%u] not match output_desc_list size[%zu]",
                         out_size, output_desc_list.size());
                  return FAILED;);
  for (uint32_t index = 0; index < out_size; ++index) {
    string output_name;
    // NOTE(review): src_name/src_index are loop-invariant and are indexed by
    // `index` without a bounds check — assumes the op records one src entry
    // per input; confirm against OpDesc invariants.
    std::vector<std::string> src_name = op_desc->GetSrcName();
    std::vector<int64_t> src_index = op_desc->GetSrcIndex();
    if (out_size == out_node_names.size()) {
      // User-specified names are used verbatim when they already carry an
      // output index ("name:idx"); otherwise the src index is appended.
      bool contains_colon = out_node_names[index].find(":") != std::string::npos;
      output_name = contains_colon ? out_node_names[index] : out_node_names[index] +
                    ":" + std::to_string(src_index[index]);
    } else {
      output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] +
                    "_" + std::to_string(src_index[index]);
    }
    InputOutputDescInfo output_desc_info;
    output_desc_info.name = output_name;
    uint32_t format_result;
    CreateOutput(output_desc_list[index], output_desc_info, format_result);
    output_desc.push_back(output_desc_info);
    formats.push_back(format_result);
  }
  return SUCCESS;
}
  320. TensorValue *HybridModel::GetConstant(const NodePtr &node) const {
  321. if (node == nullptr) {
  322. GELOGE(PARAM_INVALID, "[Check][Param:node]node is null.");
  323. REPORT_INNER_ERROR("E19999", "param node is null, check invalid.");
  324. return nullptr;
  325. }
  326. auto it = constant_tensors_.find(node);
  327. if (it == constant_tensors_.end()) {
  328. GELOGD("constant not found, node name = [%s]", node->GetName().c_str());
  329. return nullptr;
  330. }
  331. GELOGD("Got constant tensor, node name = [%s], tensor = %s",
  332. node->GetName().c_str(),
  333. it->second->DebugString().c_str());
  334. return it->second.get();
  335. }
  336. TensorValue *HybridModel::GetTensor(const NodePtr &node) const {
  337. if (node == nullptr) {
  338. GELOGE(PARAM_INVALID, "[Check][Param:node]node is null.");
  339. REPORT_INNER_ERROR("E19999", "param node is null, check invalid.");
  340. return nullptr;
  341. }
  342. if (node->GetType() == CONSTANT) {
  343. return GetConstant(node);
  344. }
  345. return GetVariable(node->GetName());
  346. }
// Accessor for host-resident tensors, keyed by node id; each entry maps to
// (output index, tensor) pairs.
const map<int64_t, std::vector<std::pair<int, Tensor>>> &HybridModel::GetHostTensors() const {
  return host_tensors_;
}
  350. void *HybridModel::GetGlobalStep() const {
  351. if (global_step_ == nullptr) {
  352. return nullptr;
  353. }
  354. return global_step_->GetData();
  355. }
  356. TensorBuffer *HybridModel::GetModelWeight(const string &subgraph_name) const {
  357. auto it = weight_buffer_map_.find(subgraph_name);
  358. if (it == weight_buffer_map_.end()) {
  359. GELOGD("Model weight not found, subgraph name = %s", subgraph_name.c_str());
  360. return nullptr;
  361. }
  362. return it->second.get();
  363. }
  364. // save specify attr values of op, such as ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES
  365. // it will save more attr values in the future
  366. void HybridModel::SaveSpecifyAttrValues() {
  367. for (const auto &node : root_graph_->GetAllNodes()) {
  368. if (node == nullptr) {
  369. continue;
  370. }
  371. auto op_desc = node->GetOpDesc();
  372. if (op_desc == nullptr) {
  373. continue;
  374. }
  375. std::vector<std::string> value;
  376. if (AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, value)) {
  377. std::map<std::string, std::vector<std::string>> attr_name_to_value;
  378. attr_name_to_value[ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES] = value;
  379. op_name_to_attrs_[op_desc->GetName()] = attr_name_to_value;
  380. GELOGD("Get op:%s attr:%s success.", op_desc->GetName().c_str(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str());
  381. }
  382. }
  383. return;
  384. }
  385. Status HybridModel::GetOpAttr(const std::string &op_name, const std::string &attr_name,
  386. std::string &attr_value) const {
  387. auto itr = op_name_to_attrs_.find(op_name);
  388. if (itr == op_name_to_attrs_.end()) {
  389. GELOGW("Did not save op:%s attr", op_name.c_str());
  390. return SUCCESS;
  391. }
  392. auto attr_itr = itr->second.find(attr_name);
  393. if (attr_itr == itr->second.end()) {
  394. GELOGW("Did not save attr:%s of op:%s", attr_name.c_str(), op_name.c_str());
  395. return SUCCESS;
  396. }
  397. for (const auto &name : attr_itr->second) {
  398. attr_value += "[" + std::to_string(name.size()) + "]" + name;
  399. }
  400. GELOGD("Get attr:%s of op:%s success, attr value:%s", attr_name.c_str(), op_name.c_str(), attr_value.c_str());
  401. return SUCCESS;
  402. }
  403. } // namespace hybrid
  404. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示