Browse Source

Replace rtLabelGotoEx by rtLabelSwitchByIndex

tags/v1.2.0
zhangxiaokun 3 years ago
parent
commit
8d8786bfd2
7 changed files with 143 additions and 56 deletions
  1. +57
    -10
      ge/ge_runtime/task/label_goto_task.cc
  2. +6
    -2
      ge/ge_runtime/task/label_goto_task.h
  3. +47
    -5
      ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc
  4. +5
    -3
      ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h
  5. +8
    -16
      ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc
  6. +9
    -11
      ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h
  7. +11
    -9
      inc/framework/common/util.h

+ 57
- 10
ge/ge_runtime/task/label_goto_task.cc View File

@@ -16,14 +16,12 @@

#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"
#include "framework/common/util.h"

namespace ge {
namespace model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
@@ -42,29 +40,78 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share
label_ = label_list[label_id];
}

LabelGotoTask::~LabelGotoTask() {}
LabelGotoTask::~LabelGotoTask() {
GE_FREE_RT_LOG(label_info_);
GE_FREE_RT_LOG(index_value_);
}

bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start.");
if (!CheckParamValid()) {
return false;
}

const std::vector<void *> label_list = { label_ };
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

uint64_t branch_index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

bool LabelGotoTask::CheckParamValid() {
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}

if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);

if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
return false;
}

if (index_value_ != nullptr) {
GELOGE(PARAM_INVALID, "index_value_ has dirty data.");
return false;
}

GELOGI("DistributeTask end.");
return true;
}

REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);

} // namespace model_runner
} // namespace ge

+ 6
- 2
ge/ge_runtime/task/label_goto_task.h View File

@@ -31,9 +31,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
bool Distribute() override;

private:
bool CheckParamValid();

std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_;
void *label_;
void *stream_{nullptr};
void *label_{nullptr};
void *label_info_{nullptr};
void *index_value_{nullptr};
};
} // namespace model_runner
} // namespace ge


+ 47
- 5
ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc View File

@@ -17,9 +17,15 @@
#include "graph/load/model_manager/task_info/label_goto_ex_task_info.h"

#include "graph/load/model_manager/davinci_model.h"
#include "graph/debug/ge_attr_define.h"

namespace ge {
constexpr uint8_t kGotoBranchMax = 1;

LabelGotoExTaskInfo::~LabelGotoExTaskInfo() {
GE_FREE_RT_LOG(args_);
GE_FREE_RT_LOG(index_value_);
}

Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
GELOGI("LabelGotoExTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model);
@@ -28,7 +34,7 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
return FAILED;
}

// Get LabelGoto task def
// Get LabelGotoEx task def
const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex();
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index());
if (op_desc == nullptr) {
@@ -48,15 +54,51 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size());
return INTERNAL_ERROR;
}
label_ = label_list[label_index];
GE_CHECK_NOTNULL(label_list[label_index]);
vector<rtLabel_t> label_used = { label_list[label_index] };

rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
GELOGI("memory_type: %u", memory_type);
args_size_ = kGotoBranchMax * sizeof(rtLabelDevInfo);
rtError_t rt_ret = rtMalloc(&args_, args_size_, memory_type);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtLabelListCpy failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_);
rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), memory_type);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

uint64_t branch_index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMemcpy failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_list[label_index]);
return SUCCESS;
}

Status LabelGotoExTaskInfo::Distribute() {
GELOGI("LabelGotoExTaskInfo Distribute Start.");
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
GE_CHECK_NOTNULL(args_);
GE_CHECK_NOTNULL(index_value_);
if (args_size_ == 0) {
GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", kGotoBranchMax, args_size_);
return PARAM_INVALID;
}

rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, kGotoBranchMax, args_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);


+ 5
- 3
ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h View File

@@ -22,16 +22,18 @@
namespace ge {
class LabelGotoExTaskInfo : public TaskInfo {
public:
LabelGotoExTaskInfo() : label_(nullptr) {}
LabelGotoExTaskInfo() = default;

~LabelGotoExTaskInfo() override { label_ = nullptr; }
~LabelGotoExTaskInfo() override;

Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

Status Distribute() override;

private:
void *label_;
void *index_value_{nullptr}; // switch index input.
void *args_{nullptr}; // label info memory.
uint32_t args_size_{0}; // label info length.
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_

+ 8
- 16
ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc View File

@@ -16,20 +16,13 @@

#include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/load/model_manager/davinci_model.h"

namespace ge {
constexpr uint8_t kLabelSwitchIndexNum = 1;

LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() {
if (args_ != nullptr) {
rtError_t ret = rtFree(args_);
if (ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret);
}
}
args_ = nullptr;
GE_FREE_RT_LOG(args_);
index_value_ = nullptr;
}

@@ -37,13 +30,12 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
GELOGI("LabelSwitchByIndexTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model);

const vector<rtLabel_t> &label_list = davinci_model->GetLabelList();
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList());
if (ret != SUCCESS) {
return FAILED;
}

// Get LabelSwitch task def
// Get LabelSwitchByIndex task def
const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index();
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index());
if (op_desc == nullptr) {
@@ -68,7 +60,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo

davinci_model->DisableZeroCopy(index_value_);

std::vector<uint32_t> label_idx_list;
vector<uint32_t> label_idx_list;
if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) {
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(),
ATTR_NAME_LABEL_SWITCH_LIST.c_str());
@@ -81,7 +73,8 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
return INTERNAL_ERROR;
}

label_list_.resize(branch_max_, nullptr);
vector<rtLabel_t> label_used(branch_max_, nullptr);
const vector<rtLabel_t> &label_list = davinci_model->GetLabelList();
for (size_t idx = 0; idx < label_idx_list.size(); ++idx) {
uint32_t label_id = label_idx_list[idx];
if (label_id >= label_list.size()) {
@@ -90,8 +83,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
return INTERNAL_ERROR;
}
GE_CHECK_NOTNULL(label_list[label_id]);

label_list_[idx] = label_list[label_id];
label_used[idx] = label_list[label_id];
}

rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
@@ -103,7 +95,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_);
rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
@@ -125,7 +117,7 @@ Status LabelSwitchByIndexTaskInfo::Distribute() {
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_FAILED;
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("LabelSwitchByIndexTaskInfo Distribute Success.");


+ 9
- 11
ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/

#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#ifndef GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#define GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_

#include "graph/load/model_manager/task_info/task_info.h"

namespace ge {
class LabelSwitchByIndexTaskInfo : public TaskInfo {
public:
LabelSwitchByIndexTaskInfo()
: index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {}
LabelSwitchByIndexTaskInfo() = default;

~LabelSwitchByIndexTaskInfo() override;

@@ -34,12 +33,11 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo {
Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

private:
void *index_value_; // switch index input.
uint32_t branch_max_; // max branch count.
void *args_; // label info memory.
uint32_t args_size_; // label info length.
std::vector<rtLabel_t> label_list_;
int64_t fixed_addr_offset_;
void *index_value_{nullptr}; // switch index input.
uint32_t branch_max_{0}; // max branch count.
void *args_{nullptr}; // label info memory.
uint32_t args_size_{0}; // label info length.
int64_t fixed_addr_offset_{0};
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#endif // GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_

+ 11
- 9
inc/framework/common/util.h View File

@@ -166,15 +166,6 @@
} \
} while (0)

// Check if the container is empty
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \
do { \
if (vector.empty()) { \
DOMI_LOGE("param[%s] is empty!", #vector); \
return ge::FAILED; \
} \
} while (0)

// Check if the value on the left is greater than or equal to the value on the right
#define GE_CHECK_GE(lhs, rhs) \
do { \
@@ -209,6 +200,17 @@
} \
} while (0)

#define GE_FREE_RT_LOG(addr) \
do { \
if (addr != nullptr) { \
rtError_t error = rtFree(addr); \
if (error != RT_ERROR_NONE) { \
GELOGE(RT_FAILED, "Call rtFree failed, error: %#x", error); \
} \
addr = nullptr; \
} \
} while (0)

/**
* @ingroup domi_common
* @brief version of om.proto file


Loading…
Cancel
Save