You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_context.cc 6.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/build/run_context.h"
  17. #include "common/util.h"
  18. #include "framework/common/debug/ge_log.h"
  19. #include "graph/debug/ge_attr_define.h"
  20. namespace ge {
  21. RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); }
  22. Status RunContextUtil::InitMemInfo(uint8_t *data_mem_base, uint64_t data_mem_size, uint8_t *weight_mem_base,
  23. uint64_t weight_mem_size) {
  24. if ((data_mem_size > 0) && (data_mem_base == nullptr)) {
  25. GELOGE(PARAM_INVALID, "InitMemInfo param data_mem_base is null but data_mem_size = %lu.", data_mem_size);
  26. return PARAM_INVALID;
  27. }
  28. if ((weight_mem_size > 0) && (weight_mem_base == nullptr)) {
  29. GELOGE(PARAM_INVALID, "InitMemInfo param weight_mem_base is null but weight_mem_size = %lu.", weight_mem_size);
  30. return PARAM_INVALID;
  31. }
  32. data_mem_base_ = data_mem_base;
  33. data_mem_size_ = data_mem_size;
  34. weight_mem_base_ = weight_mem_base;
  35. weight_mem_size_ = weight_mem_size;
  36. return SUCCESS;
  37. }
  38. Status RunContextUtil::CreateRtModelResources(uint32_t stream_num, uint32_t event_num, uint32_t label_num) {
  39. // Create rt model
  40. rtError_t rt_ret = rtModelCreate(&rt_model_, 0);
  41. if (rt_ret != RT_ERROR_NONE) {
  42. GELOGE(RT_FAILED, "rtModelCreate failed. rt_ret = %d", static_cast<int>(rt_ret));
  43. return RT_FAILED;
  44. }
  45. // Create rt Stream and bind with model
  46. for (uint32_t i = 0; i < stream_num; ++i) {
  47. rtStream_t stream = nullptr;
  48. rt_ret = rtStreamCreate(&stream, 0);
  49. if (rt_ret != RT_ERROR_NONE) {
  50. GELOGE(RT_FAILED, "rtStreamCreate failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  51. return RT_FAILED;
  52. }
  53. stream_list_.emplace_back(stream);
  54. rt_ret = rtModelBindStream(rt_model_, stream, 0);
  55. if (rt_ret != RT_ERROR_NONE) {
  56. GELOGE(RT_FAILED, "Bind stream and model failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  57. return RT_FAILED;
  58. }
  59. }
  60. // Create rt event
  61. for (uint32_t i = 0; i < event_num; ++i) {
  62. rtEvent_t event = nullptr;
  63. rt_ret = rtEventCreate(&event);
  64. if (rt_ret != RT_ERROR_NONE) {
  65. GELOGE(RT_FAILED, "rtEventCreate failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  66. return RT_FAILED;
  67. }
  68. event_list_.emplace_back(event);
  69. }
  70. // Create rt label
  71. for (uint32_t i = 0; i < label_num; ++i) {
  72. rtLabel_t label = nullptr;
  73. rt_ret = rtLabelCreate(&label);
  74. if (rt_ret != RT_ERROR_NONE) {
  75. GELOGE(RT_FAILED, "rtLabelCreate failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  76. return RT_FAILED;
  77. }
  78. label_list_.emplace_back(label);
  79. }
  80. return SUCCESS;
  81. }
  82. void RunContextUtil::DestroyRtModelResources() noexcept {
  83. rtError_t rt_ret;
  84. for (size_t i = 0; i < stream_list_.size(); i++) {
  85. // Unbind stream to model first
  86. (void)rtModelUnbindStream(rt_model_, stream_list_[i]);
  87. rt_ret = rtStreamDestroy(stream_list_[i]);
  88. if (rt_ret != RT_ERROR_NONE) {
  89. GELOGW("Destroy stream failed. rt_ret = %d, index = %zu.", static_cast<int>(rt_ret), i);
  90. }
  91. }
  92. stream_list_.clear();
  93. for (size_t i = 0; i < event_list_.size(); i++) {
  94. rt_ret = rtEventDestroy(event_list_[i]);
  95. if (rt_ret != RT_ERROR_NONE) {
  96. GELOGW("Destroy event failed. rt_ret = %d, index = %zu.", static_cast<int>(rt_ret), i);
  97. }
  98. }
  99. event_list_.clear();
  100. for (size_t i = 0; i < label_list_.size(); ++i) {
  101. rt_ret = rtLabelDestroy(label_list_[i]);
  102. if (rt_ret != RT_ERROR_NONE) {
  103. GELOGW("Destroy label failed. rt_ret = %d, index = %zu.", static_cast<int>(rt_ret), i);
  104. }
  105. }
  106. label_list_.clear();
  107. if (rt_model_ != nullptr) {
  108. rt_ret = rtModelDestroy(rt_model_);
  109. if (rt_ret != RT_ERROR_NONE) {
  110. GELOGW("Destroy rt model failed. rt_ret = %d.", static_cast<int>(rt_ret));
  111. }
  112. rt_model_ = nullptr;
  113. }
  114. }
  115. Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &graph, Buffer &buffer,
  116. const uint64_t session_id) {
  117. GELOGI("Begin to Create RunContext, session_id = %lu", session_id);
  118. // check params
  119. if (graph == nullptr) {
  120. GELOGE(PARAM_INVALID, "CreateRunContext param graph is null. session_id=%lu", session_id);
  121. return PARAM_INVALID;
  122. }
  123. uint32_t stream_num = 0;
  124. if (!AttrUtils::GetInt(&model, ATTR_MODEL_STREAM_NUM, stream_num)) {
  125. GELOGE(INTERNAL_ERROR, "Get stream_num attr from model_def failed. session_id=%lu", session_id);
  126. return INTERNAL_ERROR;
  127. }
  128. GELOGI("Stream_num = %u", stream_num);
  129. uint32_t event_num = 0;
  130. if (!AttrUtils::GetInt(&model, ATTR_MODEL_EVENT_NUM, event_num)) {
  131. GELOGE(INTERNAL_ERROR, "Get event_num attr from model failed. session_id=%lu", session_id);
  132. return INTERNAL_ERROR;
  133. }
  134. GELOGI("Event_num = %u", event_num);
  135. uint32_t label_num = 0;
  136. if (!AttrUtils::GetInt(&model, ATTR_MODEL_LABEL_NUM, label_num)) {
  137. GELOGE(INTERNAL_ERROR, "Get label_num attr from model failed. session_id=%lu", session_id);
  138. return INTERNAL_ERROR;
  139. }
  140. GELOGI("Label_num = %u", label_num);
  141. Status ret = CreateRtModelResources(stream_num, event_num, label_num);
  142. if (ret != SUCCESS) {
  143. GELOGE(ret, "CreateRtModelResources failed. session_id=%lu", session_id);
  144. DestroyRtModelResources();
  145. return ret;
  146. }
  147. GELOGI("CreateRunContext: data_mem_base_ = %p, weight_mem_base_ = %p, memory_size = %lu, weight_size = %lu",
  148. data_mem_base_, weight_mem_base_, data_mem_size_, weight_mem_size_);
  149. run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_,
  150. weight_mem_base_, buffer, stream_list_, event_list_, label_list_};
  151. return SUCCESS;
  152. }
  153. RunContext &RunContextUtil::GetRunContext() { return run_context_; }
  154. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：