Browse Source

Replace rtLabelGotoEx by rtLabelSwitchByIndex

tags/v1.2.0
zhangxiaokun 3 years ago
parent
commit
8d8786bfd2
7 changed files with 143 additions and 56 deletions
  1. +57
    -10
      ge/ge_runtime/task/label_goto_task.cc
  2. +6
    -2
      ge/ge_runtime/task/label_goto_task.h
  3. +47
    -5
      ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc
  4. +5
    -3
      ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h
  5. +8
    -16
      ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc
  6. +9
    -11
      ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h
  7. +11
    -9
      inc/framework/common/util.h

+ 57
- 10
ge/ge_runtime/task/label_goto_task.cc View File

@@ -16,14 +16,12 @@


#include "ge_runtime/task/label_goto_task.h" #include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h" #include "ge_runtime/task/task_factory.h"
#include "framework/common/util.h"


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) {
if (task_info_ == nullptr) { if (task_info_ == nullptr) {
GELOGW("task_info_ is null!"); GELOGW("task_info_ is null!");
return; return;
@@ -42,29 +40,78 @@ LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::share
label_ = label_list[label_id]; label_ = label_list[label_id];
} }


LabelGotoTask::~LabelGotoTask() {}
LabelGotoTask::~LabelGotoTask() {
GE_FREE_RT_LOG(label_info_);
GE_FREE_RT_LOG(index_value_);
}


bool LabelGotoTask::Distribute() { bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start."); GELOGI("LabelGotoTask Distribute start.");
if (!CheckParamValid()) {
return false;
}

const std::vector<void *> label_list = { label_ };
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

uint64_t branch_index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

rt_ret = rtLabelSwitchByIndex(index_value_, label_list.size(), label_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

bool LabelGotoTask::CheckParamValid() {
if (stream_ == nullptr) { if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!"); GELOGE(PARAM_INVALID, "stream is null!");
return false; return false;
} }

if (label_ == nullptr) { if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!"); GELOGE(PARAM_INVALID, "label is null!");
return false; return false;
} }
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);

if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
return false;
}

if (index_value_ != nullptr) {
GELOGE(PARAM_INVALID, "index_value_ has dirty data.");
return false; return false;
} }


GELOGI("DistributeTask end.");
return true; return true;
} }


REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);

} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge

+ 6
- 2
ge/ge_runtime/task/label_goto_task.h View File

@@ -31,9 +31,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
bool Distribute() override; bool Distribute() override;


private: private:
bool CheckParamValid();

std::shared_ptr<LabelGotoTaskInfo> task_info_; std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_;
void *label_;
void *stream_{nullptr};
void *label_{nullptr};
void *label_info_{nullptr};
void *index_value_{nullptr};
}; };
} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge


+ 47
- 5
ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc View File

@@ -17,9 +17,15 @@
#include "graph/load/model_manager/task_info/label_goto_ex_task_info.h" #include "graph/load/model_manager/task_info/label_goto_ex_task_info.h"


#include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/davinci_model.h"
#include "graph/debug/ge_attr_define.h"


namespace ge { namespace ge {
constexpr uint8_t kGotoBranchMax = 1;

LabelGotoExTaskInfo::~LabelGotoExTaskInfo() {
GE_FREE_RT_LOG(args_);
GE_FREE_RT_LOG(index_value_);
}

Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
GELOGI("LabelGotoExTaskInfo Init Start."); GELOGI("LabelGotoExTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model); GE_CHECK_NOTNULL(davinci_model);
@@ -28,7 +34,7 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
return FAILED; return FAILED;
} }


// Get LabelGoto task def
// Get LabelGotoEx task def
const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex(); const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex();
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index()); OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index());
if (op_desc == nullptr) { if (op_desc == nullptr) {
@@ -48,15 +54,51 @@ Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size()); GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
label_ = label_list[label_index];
GE_CHECK_NOTNULL(label_list[label_index]);
vector<rtLabel_t> label_used = { label_list[label_index] };

rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
GELOGI("memory_type: %u", memory_type);
args_size_ = kGotoBranchMax * sizeof(rtLabelDevInfo);
rtError_t rt_ret = rtMalloc(&args_, args_size_, memory_type);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtLabelListCpy failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}


GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_);
rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), memory_type);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMalloc failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

uint64_t branch_index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &branch_index, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtMemcpy failed, error: %#x", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_list[label_index]);
return SUCCESS; return SUCCESS;
} }


Status LabelGotoExTaskInfo::Distribute() { Status LabelGotoExTaskInfo::Distribute() {
GELOGI("LabelGotoExTaskInfo Distribute Start."); GELOGI("LabelGotoExTaskInfo Distribute Start.");
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
GE_CHECK_NOTNULL(args_);
GE_CHECK_NOTNULL(index_value_);
if (args_size_ == 0) {
GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", kGotoBranchMax, args_size_);
return PARAM_INVALID;
}

rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, kGotoBranchMax, args_, stream_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);


+ 5
- 3
ge/graph/load/model_manager/task_info/label_goto_ex_task_info.h View File

@@ -22,16 +22,18 @@
namespace ge { namespace ge {
class LabelGotoExTaskInfo : public TaskInfo { class LabelGotoExTaskInfo : public TaskInfo {
public: public:
LabelGotoExTaskInfo() : label_(nullptr) {}
LabelGotoExTaskInfo() = default;


~LabelGotoExTaskInfo() override { label_ = nullptr; }
~LabelGotoExTaskInfo() override;


Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;


Status Distribute() override; Status Distribute() override;


private: private:
void *label_;
void *index_value_{nullptr}; // switch index input.
void *args_{nullptr}; // label info memory.
uint32_t args_size_{0}; // label info length.
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_

+ 8
- 16
ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc View File

@@ -16,20 +16,13 @@


#include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h" #include "graph/load/model_manager/task_info/label_switch_by_index_task_info.h"


#include "graph/debug/ge_attr_define.h"
#include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/davinci_model.h"


namespace ge { namespace ge {
constexpr uint8_t kLabelSwitchIndexNum = 1; constexpr uint8_t kLabelSwitchIndexNum = 1;


LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() { LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() {
if (args_ != nullptr) {
rtError_t ret = rtFree(args_);
if (ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret);
}
}
args_ = nullptr;
GE_FREE_RT_LOG(args_);
index_value_ = nullptr; index_value_ = nullptr;
} }


@@ -37,13 +30,12 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
GELOGI("LabelSwitchByIndexTaskInfo Init Start."); GELOGI("LabelSwitchByIndexTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model); GE_CHECK_NOTNULL(davinci_model);


const vector<rtLabel_t> &label_list = davinci_model->GetLabelList();
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList());
if (ret != SUCCESS) { if (ret != SUCCESS) {
return FAILED; return FAILED;
} }


// Get LabelSwitch task def
// Get LabelSwitchByIndex task def
const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index(); const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index();
OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index()); OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index());
if (op_desc == nullptr) { if (op_desc == nullptr) {
@@ -68,7 +60,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo


davinci_model->DisableZeroCopy(index_value_); davinci_model->DisableZeroCopy(index_value_);


std::vector<uint32_t> label_idx_list;
vector<uint32_t> label_idx_list;
if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) { if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) {
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(), GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(),
ATTR_NAME_LABEL_SWITCH_LIST.c_str()); ATTR_NAME_LABEL_SWITCH_LIST.c_str());
@@ -81,7 +73,8 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


label_list_.resize(branch_max_, nullptr);
vector<rtLabel_t> label_used(branch_max_, nullptr);
const vector<rtLabel_t> &label_list = davinci_model->GetLabelList();
for (size_t idx = 0; idx < label_idx_list.size(); ++idx) { for (size_t idx = 0; idx < label_idx_list.size(); ++idx) {
uint32_t label_id = label_idx_list[idx]; uint32_t label_id = label_idx_list[idx];
if (label_id >= label_list.size()) { if (label_id >= label_list.size()) {
@@ -90,8 +83,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
GE_CHECK_NOTNULL(label_list[label_id]); GE_CHECK_NOTNULL(label_list[label_id]);

label_list_[idx] = label_list[label_id];
label_used[idx] = label_list[label_id];
} }


rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM; rtMemType_t memory_type = op_desc->HasAttr(ATTR_NAME_MEMORY_TYPE_RANGE) ? RT_MEMORY_TS_4G : RT_MEMORY_HBM;
@@ -103,7 +95,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo
return RT_ERROR_TO_GE_STATUS(rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);
} }


rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_);
rt_ret = rtLabelListCpy(label_used.data(), label_used.size(), args_, args_size_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);
@@ -125,7 +117,7 @@ Status LabelSwitchByIndexTaskInfo::Distribute() {
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_); rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return RT_FAILED;
return RT_ERROR_TO_GE_STATUS(rt_ret);
} }


GELOGI("LabelSwitchByIndexTaskInfo Distribute Success."); GELOGI("LabelSwitchByIndexTaskInfo Distribute Success.");


+ 9
- 11
ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.h View File

@@ -14,16 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#ifndef GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#define GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_


#include "graph/load/model_manager/task_info/task_info.h" #include "graph/load/model_manager/task_info/task_info.h"


namespace ge { namespace ge {
class LabelSwitchByIndexTaskInfo : public TaskInfo { class LabelSwitchByIndexTaskInfo : public TaskInfo {
public: public:
LabelSwitchByIndexTaskInfo()
: index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {}
LabelSwitchByIndexTaskInfo() = default;


~LabelSwitchByIndexTaskInfo() override; ~LabelSwitchByIndexTaskInfo() override;


@@ -34,12 +33,11 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo {
Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;


private: private:
void *index_value_; // switch index input.
uint32_t branch_max_; // max branch count.
void *args_; // label info memory.
uint32_t args_size_; // label info length.
std::vector<rtLabel_t> label_list_;
int64_t fixed_addr_offset_;
void *index_value_{nullptr}; // switch index input.
uint32_t branch_max_{0}; // max branch count.
void *args_{nullptr}; // label info memory.
uint32_t args_size_{0}; // label info length.
int64_t fixed_addr_offset_{0};
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_
#endif // GE_GRAPH_LOAD_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_

+ 11
- 9
inc/framework/common/util.h View File

@@ -166,15 +166,6 @@
} \ } \
} while (0) } while (0)


// Check if the container is empty
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \
do { \
if (vector.empty()) { \
DOMI_LOGE("param[%s] is empty!", #vector); \
return ge::FAILED; \
} \
} while (0)

// Check if the value on the left is greater than or equal to the value on the right // Check if the value on the left is greater than or equal to the value on the right
#define GE_CHECK_GE(lhs, rhs) \ #define GE_CHECK_GE(lhs, rhs) \
do { \ do { \
@@ -209,6 +200,17 @@
} \ } \
} while (0) } while (0)


#define GE_FREE_RT_LOG(addr) \
do { \
if (addr != nullptr) { \
rtError_t error = rtFree(addr); \
if (error != RT_ERROR_NONE) { \
GELOGE(RT_FAILED, "Call rtFree failed, error: %#x", error); \
} \
addr = nullptr; \
} \
} while (0)

/** /**
* @ingroup domi_common * @ingroup domi_common
* @brief version of om.proto file * @brief version of om.proto file


Loading…
Cancel
Save