You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

command_handle.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "command_handle.h"
  17. #include "runtime/base.h"
  18. #include "common/profiling/profiling_manager.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/omg/omg_inner_types.h"
  23. #include "graph/load/graph_loader.h"
  24. namespace {
  // Position in the start/stop parameter vector from which ExecuteCommand reads the device id list string.
  constexpr uint32_t kDeviceListIndex{3};
  // Exclusive upper bound that IsProfTypeValid applies to the raw command type.
  constexpr uint32_t kCommandNum{6};
  // Largest device count that IsProfConfigValid accepts before querying the runtime.
  constexpr int kMaxDevNum{64};

  // Keys placed in front of the device count and the device id list by TransProfConfigToParam.
  const std::string kDeviceNums{"devNums"};
  const std::string kDeviceIdList{"devIdList"};

  // Command names; kProfCommandTypeMap associates each one with a ProfCommandHandleType value.
  const std::string kProfilingInit{"prof_init"};
  const std::string kProfilingFinalize{"prof_finalize"};
  const std::string kProfilingStart{"prof_start"};
  const std::string kProfilingStop{"prof_stop"};
  const std::string kProfilingModelSubscribe{"prof_model_subscribe"};
  const std::string kProfilingModelUnsubscribe{"prof_model_cancel_subscribe"};

  // Key placed in front of the model id by the subscribe/unsubscribe helpers.
  const std::string kProfilingModelId{"modelId"};
  // Profiling command kinds. HandleCtrlSwitch casts the raw `type` field of the
  // command handle to this enum, so the numeric values must stay as they are.
  enum ProfCommandHandleType {
    kProfCommandhandleInit = 0,
    kProfCommandhandleStart = 1,
    kProfCommandhandleStop = 2,
    kProfCommandhandleFinalize = 3,
    kProfCommandhandleModelSubscribe = 4,
    kProfCommandhandleModelUnsubscribe = 5
  };
  45. const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = {
  46. {kProfCommandhandleInit, kProfilingInit},
  47. {kProfCommandhandleStart, kProfilingStart},
  48. {kProfCommandhandleStop, kProfilingStop},
  49. {kProfCommandhandleFinalize, kProfilingFinalize},
  50. {kProfCommandhandleModelSubscribe, kProfilingModelSubscribe},
  51. {kProfCommandhandleModelUnsubscribe, kProfilingModelUnsubscribe}};
  52. bool IsProfTypeValid(uint32_t type) {
  53. if (type < 0 || type >= kCommandNum) {
  54. GELOGE(ge::PARAM_INVALID, "[Check][Type]Type %u is invalid", type);
  55. return false;
  56. }
  57. GELOGD("Type is %u", type);
  58. return true;
  59. }
  60. bool IsProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) {
  61. if (deviceid_list == nullptr) {
  62. GELOGE(ge::PARAM_INVALID, "[Check][DeviceIDList]Invalid, it is nullptr");
  63. REPORT_INNER_ERROR("E19999", "Device id list is nullptr");
  64. return false;
  65. }
  66. if (device_nums == 0 || device_nums > kMaxDevNum) {
  67. GELOGE(ge::PARAM_INVALID, "[Check][DeviceNums]Invalid, device nums: %u", device_nums);
  68. REPORT_INNER_ERROR("E19999", "DeviceNums %u check invalid", device_nums);
  69. return false;
  70. }
  71. // real device num
  72. int32_t dev_count = 0;
  73. rtError_t rt_err = rtGetDeviceCount(&dev_count);
  74. if (rt_err != RT_ERROR_NONE) {
  75. GELOGE(ge::INTERNAL_ERROR, "[Get][DeviceCount]Failed, error_code %d", rt_err);
  76. REPORT_CALL_ERROR("E19999", "Get device count failed, error_code %d", rt_err);
  77. return false;
  78. }
  79. if (device_nums > static_cast<uint32_t>(dev_count)) {
  80. GELOGE(ge::PARAM_INVALID, "[Check][Param]Device num %u is not in range [1,%d]", device_nums, dev_count);
  81. REPORT_INNER_ERROR("E19999", "Device num %u check invalid, it is not in range [1,%d]", device_nums, dev_count);
  82. return false;
  83. }
  84. std::set<uint32_t> record;
  85. for (size_t i = 0; i < device_nums; ++i) {
  86. uint32_t dev_id = deviceid_list[i];
  87. if (dev_id >= static_cast<uint32_t>(dev_count)) {
  88. GELOGE(ge::PARAM_INVALID, "[Check][DeviceId]Device id %u is not in range [0,%d)", dev_id, dev_count);
  89. REPORT_CALL_ERROR("E19999", "Device id %u is not in range [0,%d)", dev_id, dev_count);
  90. return false;
  91. }
  92. if (record.count(dev_id) > 0) {
  93. GELOGE(ge::PARAM_INVALID, "[Check][DeviceId]Device id %u is duplicatedly set", dev_id);
  94. REPORT_CALL_ERROR("E19999", "Device id %u is not unique, duplicatedly set", dev_id);
  95. return false;
  96. }
  97. record.insert(dev_id);
  98. }
  99. return true;
  100. }
  101. bool TransProfConfigToParam(const rtProfCommandHandle &profCommand, vector<string> &prof_config_params) {
  102. prof_config_params.clear();
  103. prof_config_params.emplace_back(kDeviceNums);
  104. prof_config_params.emplace_back(std::to_string(profCommand.devNums));
  105. prof_config_params.emplace_back(kDeviceIdList);
  106. std::string devID = "";
  107. if (profCommand.devNums == 0) {
  108. GELOGE(ge::FAILED, "[Check][Param]The device num is invalid.");
  109. return false;
  110. }
  111. for (uint32_t i = 0; i < profCommand.devNums; i++) {
  112. devID.append(std::to_string(profCommand.devIdList[i]));
  113. if (i != profCommand.devNums - 1) {
  114. devID.append(",");
  115. }
  116. }
  117. prof_config_params.push_back(devID);
  118. return true;
  119. }
  120. ge::Status NeedUnsubscribe(ProfCommandHandleType type, bool is_subscribe, uint32_t graph_id,
  121. vector<string> &prof_params) {
  122. if (type == kProfCommandhandleModelUnsubscribe && is_subscribe) {
  123. prof_params.clear();
  124. prof_params.emplace_back(kProfilingModelId);
  125. uint32_t model_id = graph_id;
  126. if (is_subscribe) {
  127. auto &profiling_manager = ge::ProfilingManager::Instance();
  128. auto ret = profiling_manager.GetModelIdFromGraph(graph_id, model_id);
  129. if (ret != ge::SUCCESS) {
  130. GELOGE(ret, "[Get][GraphId]graph_id:%u not not found", graph_id);
  131. return ret;
  132. }
  133. }
  134. prof_params.emplace_back(std::to_string(model_id));
  135. }
  136. return ge::SUCCESS;
  137. }
  138. rtError_t NeedHandleStartEnd(ProfCommandHandleType type, rtProfCommandHandle_t *prof_config_param,
  139. std::vector<string> &prof_params) {
  140. if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) {
  141. if (!IsProfConfigValid(prof_config_param->devIdList, prof_config_param->devNums)) {
  142. return ge::FAILED;
  143. }
  144. if (!TransProfConfigToParam(*prof_config_param, prof_params)) {
  145. GELOGE(ge::PARAM_INVALID, "[Check][Param]Transfer profilerConfig to string vector failed");
  146. REPORT_CALL_ERROR("E19999", "Transfer profilerConfig to string vector failed");
  147. return ge::PARAM_INVALID;
  148. }
  149. }
  150. return ge::SUCCESS;
  151. }
  152. rtError_t NeedHandleModelSubscribe(ProfCommandHandleType type, rtProfCommandHandle_t *prof_config_param,
  153. std::vector<string> &prof_params) {
  154. if (type == kProfCommandhandleModelSubscribe) {
  155. auto &profiling_manager = ge::ProfilingManager::Instance();
  156. auto is_train = domi::GetContext().train_flag;
  157. if (is_train) {
  158. profiling_manager.SetSubscribeInfo(prof_config_param->profSwitch, prof_config_param->modelId, true);
  159. return ge::SUCCESS;
  160. }
  161. prof_params.clear();
  162. prof_params.push_back(kProfilingModelId);
  163. prof_params.push_back(std::to_string(prof_config_param->modelId));
  164. }
  165. return ge::SUCCESS;
  166. }
  167. rtError_t ExecuteCommand(ProfCommandHandleType type,
  168. std::map<ProfCommandHandleType, std::string>::const_iterator iter,
  169. rtProfCommandHandle_t *prof_config_param, std::vector<string> &prof_params) {
  170. ge::GraphLoader graph_loader;
  171. ge::Command command;
  172. command.cmd_params.clear();
  173. command.cmd_type = iter->second;
  174. command.cmd_params = prof_params;
  175. if (type != kProfCommandhandleFinalize) {
  176. command.module_index = prof_config_param->profSwitch;
  177. }
  178. GELOGI("GE commandhandle execute, Command Type: %s, data type config: 0x%lx", iter->second.c_str(),
  179. command.module_index);
  180. if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) {
  181. GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str());
  182. }
  183. ge::Status ret = graph_loader.CommandHandle(command);
  184. if (ret != ge::SUCCESS) {
  185. GELOGE(ret, "[Handle][Command]Handle profiling command failed, command type %s, error_code %u",
  186. iter->second.c_str(), ret);
  187. REPORT_CALL_ERROR("E19999", "Handle profiling command failed, command type %s, error_code %u",
  188. iter->second.c_str(), ret);
  189. return ge::FAILED;
  190. }
  191. GELOGI("Successfully execute profiling command type: %d, command 0x%lx.", type, command.module_index);
  192. return ge::SUCCESS;
  193. }
  194. rtError_t HandleCtrlSwitch(void *data) {
  195. auto &profiling_manager = ge::ProfilingManager::Instance();
  196. rtProfCommandHandle_t *prof_config_param = reinterpret_cast<rtProfCommandHandle_t *>(data);
  197. if (!IsProfTypeValid(prof_config_param->type)) {
  198. GELOGE(ge::PARAM_INVALID, "[Check][Param]The prof comand is invalid.");
  199. return ge::FAILED;
  200. }
  201. auto type = static_cast<ProfCommandHandleType>(prof_config_param->type);
  202. if (type != kProfCommandhandleFinalize) {
  203. GE_CHECK_NOTNULL(data);
  204. }
  205. auto iter = kProfCommandTypeMap.find(type);
  206. if (iter == kProfCommandTypeMap.end()) {
  207. GELOGE(ge::PARAM_INVALID, "[Check][Param]The prof comand type is invalid.");
  208. return ge::PARAM_INVALID;
  209. }
  210. std::vector<string> prof_params;
  211. ge::Status ret = NeedHandleStartEnd(type, prof_config_param, prof_params);
  212. if (ret != ge::SUCCESS) {
  213. return ret;
  214. }
  215. ret = NeedHandleModelSubscribe(type, prof_config_param, prof_params);
  216. if (ret != ge::SUCCESS) {
  217. return ret;
  218. }
  219. auto is_subscribe = profiling_manager.GetSubscribeInfo().is_subscribe;
  220. // GraphId is actually stored in prof_config_param
  221. auto graph_id = prof_config_param->modelId;
  222. ret = NeedUnsubscribe(type, is_subscribe, graph_id, prof_params);
  223. if (ret != ge::SUCCESS) {
  224. GELOGE(ret, "[Check][Param]graph_id:%u not not found", graph_id);
  225. REPORT_INPUT_ERROR(
  226. "E10001", std::vector<std::string>({"value", "parameter", "reason"}),
  227. std::vector<std::string>({std::to_string(graph_id), "GraphToModelMap", "graph_id does not exist!"}));
  228. return ge::FAILED;
  229. }
  230. return ExecuteCommand(type, iter, prof_config_param, prof_params);
  231. }
  232. } // namespace
  233. namespace ge {
  234. rtError_t CommandHandle(uint32_t rt_type, void *data, uint32_t len) {
  235. if (data == nullptr) {
  236. GELOGE(ge::PARAM_INVALID, "[Check][Param]The prof comand is invalid.");
  237. return ge::FAILED;
  238. }
  239. auto &profiling_manager = ge::ProfilingManager::Instance();
  240. if (rt_type == RT_PROF_CTRL_REPORTER) {
  241. profiling_manager.SetMsprofReporterCallback(reinterpret_cast<MsprofReporterCallback>(data));
  242. GELOGD("return with MsprofReporterCallback");
  243. return ge::SUCCESS;
  244. } else if (rt_type == RT_PROF_CTRL_SWITCH) {
  245. return HandleCtrlSwitch(data);
  246. }
  247. return ge::FAILED;
  248. }
  249. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示