diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index d357accb..ad93a98f 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -16,14 +16,12 @@ #include "ge_runtime/task/label_goto_task.h" #include "ge_runtime/task/task_factory.h" +#include "framework/common/util.h" namespace ge { namespace model_runner { LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) - : TaskRepeater(model_context, task_info), - task_info_(task_info), - stream_(nullptr), - label_(nullptr) { + : TaskRepeater(model_context, task_info), task_info_(task_info) { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); return; @@ -42,29 +40,78 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share label_ = label_list[label_id]; } -LabelGotoTask::~LabelGotoTask() {} +LabelGotoTask::~LabelGotoTask() { + GE_FREE_RT_LOG(label_info_); + GE_FREE_RT_LOG(index_value_); +} bool LabelGotoTask::Distribute() { GELOGI("LabelGotoTask Distribute start."); + if (!CheckParamValid()) { + return false; + } + + const std::vector label_list = { label_ }; + rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + return false; + } + + uint64_t branch_index = 0; + rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + return false; + } + + uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); + rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + return false; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + return false; + } + + rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +bool LabelGotoTask::CheckParamValid() { if (stream_ == nullptr) { GELOGE(PARAM_INVALID, "stream is null!"); return false; } + if (label_ == nullptr) { GELOGE(PARAM_INVALID, "label is null!"); return false; } - rtError_t rt_ret = rtLabelGotoEx(label_, stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + + if (label_info_ != nullptr) { + GELOGE(PARAM_INVALID, "label_info_ has dirty data."); + return false; + } + + if (index_value_ != nullptr) { + GELOGE(PARAM_INVALID, "index_value_ has dirty data."); return false; } - GELOGI("DistributeTask end."); return true; } REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); - } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/label_goto_task.h b/ge/ge_runtime/task/label_goto_task.h index 4fd6d1bc..addbb700 100644 --- a/ge/ge_runtime/task/label_goto_task.h +++ b/ge/ge_runtime/task/label_goto_task.h @@ -31,9 +31,13 @@ class LabelGotoTask : public TaskRepeater { bool Distribute() override; private: + bool CheckParamValid(); + std::shared_ptr task_info_; - void *stream_; - void *label_; + void *stream_{nullptr}; + void *label_{nullptr}; + void *label_info_{nullptr}; + void *index_value_{nullptr}; }; } // namespace model_runner } // namespace ge diff --git a/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc b/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc index 1921c85d..2d108faa 100755 --- a/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc +++ b/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc @@ -17,9 +17,15 @@ #include "graph/load/model_manager/task_info/label_goto_ex_task_info.h" #include "graph/load/model_manager/davinci_model.h" -#include "graph/debug/ge_attr_define.h" namespace ge { +constexpr uint8_t kGotoBranchMax = 1; + +LabelGotoExTaskInfo::~LabelGotoExTaskInfo() { + GE_FREE_RT_LOG(args_); + GE_FREE_RT_LOG(index_value_); +} + Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("LabelGotoExTaskInfo Init Start."); GE_CHECK_NOTNULL(davinci_model); @@ -28,7 +34,7 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da return FAILED; } - // Get LabelGoto task def + // Get LabelGotoEx task def const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex(); OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index()); if (op_desc == nullptr) { @@ -48,15 +54,51 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size()); return INTERNAL_ERROR; } - label_ = label_list[label_index]; + GE_CHECK_NOTNULL(label_list[label_index]); + vector label_used = { label_list[label_index] }; + + rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; + GELOGI("memory_type: %u", memory_type); + args_size_ = kGotoBranchMax * sizeof(rtLabelDevInfo); + rtError_t rt_ret = rtMalloc(&args_, args_size_, memory_type); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtLabelListCpy failed, error: %#x", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } - GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_); + rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), memory_type); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + uint64_t branch_index = 0; + rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMemcpy failed, error: %#x", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_list[label_index]); return SUCCESS; } Status LabelGotoExTaskInfo::Distribute() { GELOGI("LabelGotoExTaskInfo Distribute Start."); - rtError_t rt_ret = rtLabelGotoEx(label_, stream_); + GE_CHECK_NOTNULL(args_); + GE_CHECK_NOTNULL(index_value_); + if (args_size_ == 0) { + GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", kGotoBranchMax, args_size_); + return PARAM_INVALID; + } + + rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, kGotoBranchMax, args_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); diff --git a/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h b/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h index 25310368..3c791e7b 100755 --- a/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h +++ b/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h @@ -22,16 +22,18 @@ namespace ge { class LabelGotoExTaskInfo : public TaskInfo { public: - LabelGotoExTaskInfo() : label_(nullptr) {} + LabelGotoExTaskInfo() = default; - ~LabelGotoExTaskInfo() override { label_ = nullptr; } + ~LabelGotoExTaskInfo() override; Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; Status Distribute() override; private: - void *label_; + void *index_value_{nullptr}; // switch index input. + void *args_{nullptr}; // label info memory. + uint32_t args_size_{0}; // label info length. }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ diff --git a/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc b/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc index c2997678..cf162f7e 100644 --- a/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc +++ b/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc @@ -16,20 +16,13 @@ #include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h" -#include "graph/debug/ge_attr_define.h" #include "graph/load/model_manager/davinci_model.h" namespace ge { constexpr uint8_t kLabelSwitchIndexNum = 1; LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() { - if (args_ != nullptr) { - rtError_t ret = rtFree(args_); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); - } - } - args_ = nullptr; + GE_FREE_RT_LOG(args_); index_value_ = nullptr; } @@ -37,13 +30,12 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo GELOGI("LabelSwitchByIndexTaskInfo Init Start."); GE_CHECK_NOTNULL(davinci_model); - const vector &label_list = davinci_model->GetLabelList(); Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); if (ret != SUCCESS) { return FAILED; } - // Get LabelSwitch task def + // Get LabelSwitchByIndex task def const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index(); OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index()); if (op_desc == nullptr) { @@ -68,7 +60,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo davinci_model->DisableZeroCopy(index_value_); - std::vector label_idx_list; + vector label_idx_list; if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) { GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(), ATTR_NAME_LABEL_SWITCH_LIST.c_str()); @@ -81,7 +73,8 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo return INTERNAL_ERROR; } - label_list_.resize(branch_max_, nullptr); + vector label_used(branch_max_, nullptr); + const vector &label_list = davinci_model->GetLabelList(); for (size_t idx = 0; idx < label_idx_list.size(); ++idx) { uint32_t label_id = label_idx_list[idx]; if (label_id >= label_list.size()) { @@ -90,8 +83,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo return INTERNAL_ERROR; } GE_CHECK_NOTNULL(label_list[label_id]); - - label_list_[idx] = label_list[label_id]; + label_used[idx] = label_list[label_id]; } rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; @@ -103,7 +95,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo return RT_ERROR_TO_GE_STATUS(rt_ret); } - rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_); + rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret); @@ -125,7 +117,7 @@ Status LabelSwitchByIndexTaskInfo::Distribute() { rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("LabelSwitchByIndexTaskInfo Distribute Success."); diff --git a/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h b/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h index 00ca0844..5a8ac05a 100644 --- a/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h +++ b/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h @@ -14,16 +14,15 @@ * limitations under the License. */ -#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ -#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ +#ifndef GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ +#define GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ #include "graph/load/model_manager/task_info/task_info.h" namespace ge { class LabelSwitchByIndexTaskInfo : public TaskInfo { public: - LabelSwitchByIndexTaskInfo() - : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {} + LabelSwitchByIndexTaskInfo() = default; ~LabelSwitchByIndexTaskInfo() override; @@ -34,12 +33,11 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo { Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; private: - void *index_value_; // switch index input. - uint32_t branch_max_; // max branch count. - void *args_; // label info memory. - uint32_t args_size_; // label info length. - std::vector label_list_; - int64_t fixed_addr_offset_; + void *index_value_{nullptr}; // switch index input. + uint32_t branch_max_{0}; // max branch count. + void *args_{nullptr}; // label info memory. + uint32_t args_size_{0}; // label info length. + int64_t fixed_addr_offset_{0}; }; } // namespace ge -#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ \ No newline at end of file +#endif // GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ \ No newline at end of file diff --git a/inc/framework/common/util.h b/inc/framework/common/util.h index 525cf3ea..bcc3c99b 100644 --- a/inc/framework/common/util.h +++ b/inc/framework/common/util.h @@ -166,15 +166,6 @@ } \ } while (0) -// Check if the container is empty -#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ - do { \ - if (vector.empty()) { \ - DOMI_LOGE("param[%s] is empty!", #vector); \ - return ge::FAILED; \ - } \ - } while (0) - // Check if the value on the left is greater than or equal to the value on the right #define GE_CHECK_GE(lhs, rhs) \ do { \ @@ -209,6 +200,17 @@ } \ } while (0) +#define GE_FREE_RT_LOG(addr) \ + do { \ + if (addr != nullptr) { \ + rtError_t error = rtFree(addr); \ + if (error != RT_ERROR_NONE) { \ + GELOGE(RT_FAILED, "Call rtFree failed, error: %#x", error); \ + } \ + addr = nullptr; \ + } \ + } while (0) + /** * @ingroup domi_common * @brief version of om.proto file