| @@ -16,14 +16,12 @@ | |||
| #include "ge_runtime/task/label_goto_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| #include "framework/common/util.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| @@ -42,29 +40,78 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelGotoTask::~LabelGotoTask() {} | |||
| LabelGotoTask::~LabelGotoTask() { | |||
| GE_FREE_RT_LOG(label_info_); | |||
| GE_FREE_RT_LOG(index_value_); | |||
| } | |||
| bool LabelGotoTask::Distribute() { | |||
| GELOGI("LabelGotoTask Distribute start."); | |||
| if (!CheckParamValid()) { | |||
| return false; | |||
| } | |||
| const std::vector<void *> label_list = { label_ }; | |||
| rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| uint64_t branch_index = 0; | |||
| rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
| rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| bool LabelGotoTask::CheckParamValid() { | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| if (label_info_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
| return false; | |||
| } | |||
| if (index_value_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "index_value_ has dirty data."); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -31,9 +31,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
| bool Distribute() override; | |||
| private: | |||
| bool CheckParamValid(); | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| void *stream_{nullptr}; | |||
| void *label_{nullptr}; | |||
| void *label_info_{nullptr}; | |||
| void *index_value_{nullptr}; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -17,9 +17,15 @@ | |||
| #include "graph/load/model_manager/task_info/label_goto_ex_task_info.h" | |||
| #include "graph/load/model_manager/davinci_model.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| namespace ge { | |||
| constexpr uint8_t kGotoBranchMax = 1; | |||
| LabelGotoExTaskInfo::~LabelGotoExTaskInfo() { | |||
| GE_FREE_RT_LOG(args_); | |||
| GE_FREE_RT_LOG(index_value_); | |||
| } | |||
| Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||
| GELOGI("LabelGotoExTaskInfo Init Start."); | |||
| GE_CHECK_NOTNULL(davinci_model); | |||
| @@ -28,7 +34,7 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da | |||
| return FAILED; | |||
| } | |||
| // Get LabelGoto task def | |||
| // Get LabelGotoEx task def | |||
| const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex(); | |||
| OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index()); | |||
| if (op_desc == nullptr) { | |||
| @@ -48,15 +54,51 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da | |||
| GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| label_ = label_list[label_index]; | |||
| GE_CHECK_NOTNULL(label_list[label_index]); | |||
| vector<rtLabel_t> label_used = { label_list[label_index] }; | |||
| rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; | |||
| GELOGI("memory_type: %u", memory_type); | |||
| args_size_ = kGotoBranchMax * sizeof(rtLabelDevInfo); | |||
| rtError_t rt_ret = rtMalloc(&args_, args_size_, memory_type); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtLabelListCpy failed, error: %#x", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_); | |||
| rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), memory_type); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| uint64_t branch_index = 0; | |||
| rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtMemcpy failed, error: %#x", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_list[label_index]); | |||
| return SUCCESS; | |||
| } | |||
| Status LabelGotoExTaskInfo::Distribute() { | |||
| GELOGI("LabelGotoExTaskInfo Distribute Start."); | |||
| rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
| GE_CHECK_NOTNULL(args_); | |||
| GE_CHECK_NOTNULL(index_value_); | |||
| if (args_size_ == 0) { | |||
| GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", kGotoBranchMax, args_size_); | |||
| return PARAM_INVALID; | |||
| } | |||
| rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, kGotoBranchMax, args_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| @@ -22,16 +22,18 @@ | |||
| namespace ge { | |||
| class LabelGotoExTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelGotoExTaskInfo() : label_(nullptr) {} | |||
| LabelGotoExTaskInfo() = default; | |||
| ~LabelGotoExTaskInfo() override { label_ = nullptr; } | |||
| ~LabelGotoExTaskInfo() override; | |||
| Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
| Status Distribute() override; | |||
| private: | |||
| void *label_; | |||
| void *index_value_{nullptr}; // switch index input. | |||
| void *args_{nullptr}; // label info memory. | |||
| uint32_t args_size_{0}; // label info length. | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ | |||
| @@ -16,20 +16,13 @@ | |||
| #include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/load/model_manager/davinci_model.h" | |||
| namespace ge { | |||
| constexpr uint8_t kLabelSwitchIndexNum = 1; | |||
| LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() { | |||
| if (args_ != nullptr) { | |||
| rtError_t ret = rtFree(args_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||
| } | |||
| } | |||
| args_ = nullptr; | |||
| GE_FREE_RT_LOG(args_); | |||
| index_value_ = nullptr; | |||
| } | |||
| @@ -37,13 +30,12 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
| GELOGI("LabelSwitchByIndexTaskInfo Init Start."); | |||
| GE_CHECK_NOTNULL(davinci_model); | |||
| const vector<rtLabel_t> &label_list = davinci_model->GetLabelList(); | |||
| Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | |||
| if (ret != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| // Get LabelSwitch task def | |||
| // Get LabelSwitchByIndex task def | |||
| const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index(); | |||
| OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index()); | |||
| if (op_desc == nullptr) { | |||
| @@ -68,7 +60,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
| davinci_model->DisableZeroCopy(index_value_); | |||
| std::vector<uint32_t> label_idx_list; | |||
| vector<uint32_t> label_idx_list; | |||
| if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) { | |||
| GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(), | |||
| ATTR_NAME_LABEL_SWITCH_LIST.c_str()); | |||
| @@ -81,7 +73,8 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
| return INTERNAL_ERROR; | |||
| } | |||
| label_list_.resize(branch_max_, nullptr); | |||
| vector<rtLabel_t> label_used(branch_max_, nullptr); | |||
| const vector<rtLabel_t> &label_list = davinci_model->GetLabelList(); | |||
| for (size_t idx = 0; idx < label_idx_list.size(); ++idx) { | |||
| uint32_t label_id = label_idx_list[idx]; | |||
| if (label_id >= label_list.size()) { | |||
| @@ -90,8 +83,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
| return INTERNAL_ERROR; | |||
| } | |||
| GE_CHECK_NOTNULL(label_list[label_id]); | |||
| label_list_[idx] = label_list[label_id]; | |||
| label_used[idx] = label_list[label_id]; | |||
| } | |||
| rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; | |||
| @@ -103,7 +95,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_); | |||
| rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| @@ -125,7 +117,7 @@ Status LabelSwitchByIndexTaskInfo::Distribute() { | |||
| rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return RT_FAILED; | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| GELOGI("LabelSwitchByIndexTaskInfo Distribute Success."); | |||
| @@ -14,16 +14,15 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
| #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
| #ifndef GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
| #define GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
| #include "graph/load/model_manager/task_info/task_info.h" | |||
| namespace ge { | |||
| class LabelSwitchByIndexTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSwitchByIndexTaskInfo() | |||
| : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {} | |||
| LabelSwitchByIndexTaskInfo() = default; | |||
| ~LabelSwitchByIndexTaskInfo() override; | |||
| @@ -34,12 +33,11 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo { | |||
| Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
| private: | |||
| void *index_value_; // switch index input. | |||
| uint32_t branch_max_; // max branch count. | |||
| void *args_; // label info memory. | |||
| uint32_t args_size_; // label info length. | |||
| std::vector<rtLabel_t> label_list_; | |||
| int64_t fixed_addr_offset_; | |||
| void *index_value_{nullptr}; // switch index input. | |||
| uint32_t branch_max_{0}; // max branch count. | |||
| void *args_{nullptr}; // label info memory. | |||
| uint32_t args_size_{0}; // label info length. | |||
| int64_t fixed_addr_offset_{0}; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
| #endif // GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
| @@ -166,15 +166,6 @@ | |||
| } \ | |||
| } while (0) | |||
| // Check if the container is empty | |||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
| do { \ | |||
| if (vector.empty()) { \ | |||
| DOMI_LOGE("param[%s] is empty!", #vector); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the value on the left is greater than or equal to the value on the right | |||
| #define GE_CHECK_GE(lhs, rhs) \ | |||
| do { \ | |||
| @@ -209,6 +200,17 @@ | |||
| } \ | |||
| } while (0) | |||
| #define GE_FREE_RT_LOG(addr) \ | |||
| do { \ | |||
| if (addr != nullptr) { \ | |||
| rtError_t error = rtFree(addr); \ | |||
| if (error != RT_ERROR_NONE) { \ | |||
| GELOGE(RT_FAILED, "Call rtFree failed, error: %#x", error); \ | |||
| } \ | |||
| addr = nullptr; \ | |||
| } \ | |||
| } while (0) | |||
| /** | |||
| * @ingroup domi_common | |||
| * @brief version of om.proto file | |||