
rts_node_task.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hybrid/node_executor/rts/rts_node_task.h"
#include "hybrid/node_executor/rts/rts_task_factory.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"
#include "common/ge/ge_util.h"
#include "framework/common/op/ge_op_utils.h"

namespace {
constexpr uint8_t kSwitchPredIndex = 0;
constexpr uint8_t kSwitchCompIndex = 1;

const static std::map<rtCondition_t, std::function<bool(int64_t, int64_t)>> kCompHandle = {
    {RT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value == comp_value; }},
    {RT_NOT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value != comp_value; }},
    {RT_GREATER, [](int64_t pred_value, int64_t comp_value) { return pred_value > comp_value; }},
    {RT_GREATER_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value >= comp_value; }},
    {RT_LESS, [](int64_t pred_value, int64_t comp_value) { return pred_value < comp_value; }},
    {RT_LESS_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value <= comp_value; }},
};
}  // namespace

namespace ge {
namespace hybrid {
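// Register a task creator for each runtime (RTS) op type handled by the hybrid executor.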
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);

REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(LOOPCOND, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(NEXTITERATION, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFNEXTITERATION, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(EXIT, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFEXIT, PassThroughNodeTask);

REGISTER_RTS_TASK_CREATOR(LABELSET, LabelSetNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELGOTO, LabelGotoNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELGOTOEX, LabelGotoNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELSWITCH, LabelSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELSWITCHBYINDEX, LabelSwitchNodeTask);

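// Reads the scalar input at `index` and widens it to int64_t; only DT_INT32 and DT_INT64 are accepted.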
Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value) {
  auto tensor_value = task_context.GetInput(index);
  GE_CHECK_NOTNULL(tensor_value);
  auto tensor_desc = task_context.MutableInputDesc(index);
  GE_CHECK_NOTNULL(tensor_desc);

  auto data_type = tensor_desc->GetDataType();
  switch (data_type) {
#define CASE_TYPE(DT, VT)                                              \
  case (DT): {                                                         \
    VT data_val{};                                                     \
    GE_CHK_STATUS_RET(tensor_value->CopyScalarValueToHost(data_val));  \
    value = static_cast<int64_t>(data_val);                            \
    break;                                                             \
  }
    // Only index data types are accepted.
    CASE_TYPE(DT_INT32, int32_t)
    CASE_TYPE(DT_INT64, int64_t)
#undef CASE_TYPE
    default: {
      GELOGE(UNSUPPORTED, "Data type %s is not an index type.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return UNSUPPORTED;
    }
  }
  return SUCCESS;
}

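// StreamActive: delegates to NodeState::RunStreamActive and registers the completion callback.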
Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  const auto &node_state = task_context.GetNodeState();
  node_state->RunStreamActive();
  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGI("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

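// StreamSwitch: Init() selects the comparison functor from ATTR_NAME_STREAM_SWITCH_COND;
// ExecuteAsync() compares the predicate and comparand inputs and records the branch (0 or 1)
// on the node state via SetSwitchIndex.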
Status StreamSwitchNodeTask::Init(const HybridModel &model, const NodePtr &node) {
  uint32_t value = 0;
  if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, value)) {
    GELOGE(INTERNAL_ERROR, "[%s] Get %s failed.", node->GetName().c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str());
    return INTERNAL_ERROR;
  }
  rtCondition_t cond = static_cast<rtCondition_t>(value);
  const auto it = kCompHandle.find(cond);
  if (it == kCompHandle.end()) {
    GELOGE(INTERNAL_ERROR, "[%s] Get handle for condition %u failed.", node->GetName().c_str(), value);
    return INTERNAL_ERROR;
  }
  comp_func_ = it->second;

  GELOGD("[%s] Done initialization successfully, condition is %u.", node->GetName().c_str(), value);
  return SUCCESS;
}

Status StreamSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  GE_CHECK_NOTNULL(comp_func_);

  int64_t pred_value = 0;
  GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchPredIndex, pred_value));
  int64_t comp_value = 0;
  GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchCompIndex, comp_value));

  bool switch_idx = comp_func_(pred_value, comp_value);
  auto node_state = task_context.GetNodeState();
  node_state->SetSwitchIndex(static_cast<int>(switch_idx));

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGI("[%s] Done executing successfully, pred value: %ld, comp value: %ld, switch index: %d.",
         task_context.GetNodeName(), pred_value, comp_value, static_cast<int>(switch_idx));
  return SUCCESS;
}

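// StreamMerge: forwards the input chosen by the merge index to the data output (y) and
// writes that index to the value_index output on the stream.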
Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  int index = task_context.GetNodeState()->GetMergeIndex();
  GELOGD("[%s] Start to execute, merge index: %d.", task_context.GetNodeName(), index);
  if (index < 0 || index >= task_context.NumInputs()) {
    GELOGE(INTERNAL_ERROR, "[%s] Invalid merge param, inputs num: %d, merge index: %d.",
           task_context.GetNodeName(), task_context.NumInputs(), index);
    return INTERNAL_ERROR;
  }

  const auto in_x = task_context.MutableInput(index);  // x
  GE_CHECK_NOTNULL(in_x);
  GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(MERGE_DATA_OUTPUT, *in_x));  // y

  const auto out_y = task_context.MutableOutput(MERGE_INDEX_OUTPUT);  // value_index
  GE_CHECK_NOTNULL(out_y);
  if (out_y->GetSize() > 0) {
    GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), &index, sizeof(index),
                                RT_MEMCPY_HOST_TO_DEVICE_EX, task_context.GetStream()));
  }

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  task_context.GetNodeState()->SetMergeIndex(-1);  // Invalidate for loop.
  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

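// PassThrough (Enter/Exit/LoopCond/NextIteration, etc.): forwards input x to output y;
// NextIteration-type nodes additionally advance the loop via RunNextIteration.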
Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  const auto in_x = task_context.GetInput(0);  // x
  GE_CHECK_NOTNULL(in_x);
  GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x));  // y

  const auto &node_state = task_context.GetNodeState();
  if (kNextIterationOpTypes.count(node_state->GetType()) > 0) {
    node_state->RunNextIteration();
  }

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

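// Label tasks (LabelSet/LabelGoto/LabelSwitch) are not supported by the hybrid executor
// and always return UNSUPPORTED.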
Status LabelSetNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return UNSUPPORTED;
}

Status LabelGotoNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return UNSUPPORTED;
}

Status LabelSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return UNSUPPORTED;
}
}  // namespace hybrid
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as the bridge between them. GE takes the graph handed down by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
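To make the GE API / GE Core split concrete, the sketch below walks through the typical client flow: initialize GE, create a Session, add a graph, run it, and finalize. This is a minimal illustration assuming the public ge_api.h interface (GEInitialize, Session::AddGraph, Session::RunGraph, GEFinalize); the header paths, option keys, and the way the input Graph is built are placeholders and may differ between GE versions.

// Minimal sketch of driving GE through its public API; header paths and option
// keys are assumptions and may vary between GE releases.
#include <map>
#include <string>
#include <vector>

#include "ge/ge_api.h"    // GEInitialize / GEFinalize / Session (assumed path)
#include "graph/graph.h"  // ge::Graph
#include "graph/tensor.h" // ge::Tensor

ge::Status RunGraphOnce(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs,
                        std::vector<ge::Tensor> &outputs) {
  std::map<std::string, std::string> options;  // global options (placeholder, e.g. device id)

  // 1. Initialize GE Core once per process.
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return ge::FAILED;
  }

  // 2. A Session owns the graphs it compiles; session-level options are also placeholders.
  ge::Session session(options);

  // 3. Register the graph under an id, then run it; GE optimizes and compiles the graph
  //    for the Ascend AI processor before execution.
  const uint32_t graph_id = 1U;
  ge::Status ret = session.AddGraph(graph_id, graph);
  if (ret == ge::SUCCESS) {
    ret = session.RunGraph(graph_id, inputs, outputs);
  }

  // 4. Release GE resources.
  (void)ge::GEFinalize();
  return ret;
}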