@@ -16,14 +16,12 @@ | |||
#include "ge_runtime/task/label_goto_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
#include "framework/common/util.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
label_(nullptr) { | |||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
@@ -42,29 +40,78 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share | |||
label_ = label_list[label_id]; | |||
} | |||
LabelGotoTask::~LabelGotoTask() {} | |||
LabelGotoTask::~LabelGotoTask() { | |||
GE_FREE_RT_LOG(label_info_); | |||
GE_FREE_RT_LOG(index_value_); | |||
} | |||
bool LabelGotoTask::Distribute() { | |||
GELOGI("LabelGotoTask Distribute start."); | |||
if (!CheckParamValid()) { | |||
return false; | |||
} | |||
const std::vector<void *> label_list = { label_ }; | |||
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
return false; | |||
} | |||
uint64_t branch_index = 0; | |||
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
return false; | |||
} | |||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
return false; | |||
} | |||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
return false; | |||
} | |||
rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
bool LabelGotoTask::CheckParamValid() { | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (label_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label is null!"); | |||
return false; | |||
} | |||
rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
if (label_info_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
return false; | |||
} | |||
if (index_value_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "index_value_ has dirty data."); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -31,9 +31,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
bool Distribute() override; | |||
private: | |||
bool CheckParamValid(); | |||
std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
void *stream_; | |||
void *label_; | |||
void *stream_{nullptr}; | |||
void *label_{nullptr}; | |||
void *label_info_{nullptr}; | |||
void *index_value_{nullptr}; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
@@ -17,9 +17,15 @@ | |||
#include "graph/load/model_manager/task_info/label_goto_ex_task_info.h" | |||
#include "graph/load/model_manager/davinci_model.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
namespace ge { | |||
constexpr uint8_t kGotoBranchMax = 1; | |||
LabelGotoExTaskInfo::~LabelGotoExTaskInfo() { | |||
GE_FREE_RT_LOG(args_); | |||
GE_FREE_RT_LOG(index_value_); | |||
} | |||
Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||
GELOGI("LabelGotoExTaskInfo Init Start."); | |||
GE_CHECK_NOTNULL(davinci_model); | |||
@@ -28,7 +34,7 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da | |||
return FAILED; | |||
} | |||
// Get LabelGoto task def | |||
// Get LabelGotoEx task def | |||
const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex(); | |||
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index()); | |||
if (op_desc == nullptr) { | |||
@@ -48,15 +54,51 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da | |||
GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size()); | |||
return INTERNAL_ERROR; | |||
} | |||
label_ = label_list[label_index]; | |||
GE_CHECK_NOTNULL(label_list[label_index]); | |||
vector<rtLabel_t> label_used = { label_list[label_index] }; | |||
rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; | |||
GELOGI("memory_type: %u", memory_type); | |||
args_size_ = kGotoBranchMax * sizeof(rtLabelDevInfo); | |||
rtError_t rt_ret = rtMalloc(&args_, args_size_, memory_type); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rtLabelListCpy failed, error: %#x", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_); | |||
rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), memory_type); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
uint64_t branch_index = 0; | |||
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rtMemcpy failed, error: %#x", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_list[label_index]); | |||
return SUCCESS; | |||
} | |||
Status LabelGotoExTaskInfo::Distribute() { | |||
GELOGI("LabelGotoExTaskInfo Distribute Start."); | |||
rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
GE_CHECK_NOTNULL(args_); | |||
GE_CHECK_NOTNULL(index_value_); | |||
if (args_size_ == 0) { | |||
GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", kGotoBranchMax, args_size_); | |||
return PARAM_INVALID; | |||
} | |||
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, kGotoBranchMax, args_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
@@ -22,16 +22,18 @@ | |||
namespace ge { | |||
class LabelGotoExTaskInfo : public TaskInfo { | |||
public: | |||
LabelGotoExTaskInfo() : label_(nullptr) {} | |||
LabelGotoExTaskInfo() = default; | |||
~LabelGotoExTaskInfo() override { label_ = nullptr; } | |||
~LabelGotoExTaskInfo() override; | |||
Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
Status Distribute() override; | |||
private: | |||
void *label_; | |||
void *index_value_{nullptr}; // switch index input. | |||
void *args_{nullptr}; // label info memory. | |||
uint32_t args_size_{0}; // label info length. | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ |
@@ -16,20 +16,13 @@ | |||
#include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/load/model_manager/davinci_model.h" | |||
namespace ge { | |||
constexpr uint8_t kLabelSwitchIndexNum = 1; | |||
LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() { | |||
if (args_ != nullptr) { | |||
rtError_t ret = rtFree(args_); | |||
if (ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||
} | |||
} | |||
args_ = nullptr; | |||
GE_FREE_RT_LOG(args_); | |||
index_value_ = nullptr; | |||
} | |||
@@ -37,13 +30,12 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
GELOGI("LabelSwitchByIndexTaskInfo Init Start."); | |||
GE_CHECK_NOTNULL(davinci_model); | |||
const vector<rtLabel_t> &label_list = davinci_model->GetLabelList(); | |||
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | |||
if (ret != SUCCESS) { | |||
return FAILED; | |||
} | |||
// Get LabelSwitch task def | |||
// Get LabelSwitchByIndex task def | |||
const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index(); | |||
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index()); | |||
if (op_desc == nullptr) { | |||
@@ -68,7 +60,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
davinci_model->DisableZeroCopy(index_value_); | |||
std::vector<uint32_t> label_idx_list; | |||
vector<uint32_t> label_idx_list; | |||
if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) { | |||
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(), | |||
ATTR_NAME_LABEL_SWITCH_LIST.c_str()); | |||
@@ -81,7 +73,8 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
return INTERNAL_ERROR; | |||
} | |||
label_list_.resize(branch_max_, nullptr); | |||
vector<rtLabel_t> label_used(branch_max_, nullptr); | |||
const vector<rtLabel_t> &label_list = davinci_model->GetLabelList(); | |||
for (size_t idx = 0; idx < label_idx_list.size(); ++idx) { | |||
uint32_t label_id = label_idx_list[idx]; | |||
if (label_id >= label_list.size()) { | |||
@@ -90,8 +83,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
return INTERNAL_ERROR; | |||
} | |||
GE_CHECK_NOTNULL(label_list[label_id]); | |||
label_list_[idx] = label_list[label_id]; | |||
label_used[idx] = label_list[label_id]; | |||
} | |||
rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; | |||
@@ -103,7 +95,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_); | |||
rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
@@ -125,7 +117,7 @@ Status LabelSwitchByIndexTaskInfo::Distribute() { | |||
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return RT_FAILED; | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
GELOGI("LabelSwitchByIndexTaskInfo Distribute Success."); | |||
@@ -14,16 +14,15 @@ | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
#ifndef GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
#define GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
#include "graph/load/model_manager/task_info/task_info.h" | |||
namespace ge { | |||
class LabelSwitchByIndexTaskInfo : public TaskInfo { | |||
public: | |||
LabelSwitchByIndexTaskInfo() | |||
: index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {} | |||
LabelSwitchByIndexTaskInfo() = default; | |||
~LabelSwitchByIndexTaskInfo() override; | |||
@@ -34,12 +33,11 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo { | |||
Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
private: | |||
void *index_value_; // switch index input. | |||
uint32_t branch_max_; // max branch count. | |||
void *args_; // label info memory. | |||
uint32_t args_size_; // label info length. | |||
std::vector<rtLabel_t> label_list_; | |||
int64_t fixed_addr_offset_; | |||
void *index_value_{nullptr}; // switch index input. | |||
uint32_t branch_max_{0}; // max branch count. | |||
void *args_{nullptr}; // label info memory. | |||
uint32_t args_size_{0}; // label info length. | |||
int64_t fixed_addr_offset_{0}; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ | |||
#endif // GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ |
@@ -166,15 +166,6 @@ | |||
} \ | |||
} while (0) | |||
// Check if the container is empty | |||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
do { \ | |||
if (vector.empty()) { \ | |||
DOMI_LOGE("param[%s] is empty!", #vector); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} while (0) | |||
// Check if the value on the left is greater than or equal to the value on the right | |||
#define GE_CHECK_GE(lhs, rhs) \ | |||
do { \ | |||
@@ -209,6 +200,17 @@ | |||
} \ | |||
} while (0) | |||
#define GE_FREE_RT_LOG(addr) \ | |||
do { \ | |||
if (addr != nullptr) { \ | |||
rtError_t error = rtFree(addr); \ | |||
if (error != RT_ERROR_NONE) { \ | |||
GELOGE(RT_FAILED, "Call rtFree failed, error: %#x", error); \ | |||
} \ | |||
addr = nullptr; \ | |||
} \ | |||
} while (0) | |||
/** | |||
* @ingroup domi_common | |||
* @brief version of om.proto file | |||