Browse Source

Support Aicpu Dynamic Shape

tags/v1.0.0
gukecai caifubi 4 years ago
parent
commit
03405224ad
3 changed files with 46 additions and 5 deletions
  1. +4
    -1
      inc/framework/ge_runtime/task_info.h
  2. +41
    -4
      src/ge/ge_runtime/task/aicpu_task.cc
  3. +1
    -0
      src/ge/ge_runtime/task/aicpu_task.h

+ 4
- 1
inc/framework/ge_runtime/task_info.h View File

@@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo {
class AicpuTaskInfo : public TaskInfo {
public:
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
const std::string &node_def, const std::vector<void *> &input_data_addrs,
const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
so_name_(so_name),
kernel_name_(kernel_name),
node_def_(node_def),
ext_info_(ext_info),
input_data_addrs_(input_data_addrs),
output_data_addrs_(output_data_addrs) {}
~AicpuTaskInfo() override {}
@@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo {
const std::string &node_def() const { return node_def_; }
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
const std::string &ext_info() const { return ext_info_; }

private:
std::string so_name_;
std::string kernel_name_;
std::string node_def_;
std::string ext_info_;
std::vector<void *> input_data_addrs_;
std::vector<void *> output_data_addrs_;
};


+ 41
- 4
src/ge/ge_runtime/task/aicpu_task.cc View File

@@ -47,10 +47,36 @@ bool AicpuTask::Distribute() {
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size;
uint32_t args_size =
sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast<uint32_t>(task_info_->node_def().size());
aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num};
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);

aicpu::AicpuParamHead aicpu_param_head;
aicpu_param_head.length = args_size;
aicpu_param_head.ioAddrNum = io_addrs_num;
auto ext_info = task_info_->ext_info();
uint32_t ext_size = ext_info.size();
if (ext_info.empty()) {
aicpu_param_head.extInfoLength = 0;
aicpu_param_head.extInfoAddr = 0;
} else {
rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
if (flag != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag);
return false;
}

flag = rtMemcpy(ext_info_, ext_size, reinterpret_cast<void *>(ext_info.data()), ext_size, RT_MEMCPY_HOST_TO_DEVICE);
if (flag != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag);
return false;
}

GELOGI("ext info size:", ext_size);
aicpu_param_head.extInfoLength = ext_size;
aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
}

// Malloc device memory for args
rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
@@ -76,6 +102,17 @@ bool AicpuTask::Distribute() {
return false;
}
}

// Memcpy node def
auto size = task_info_->node_def().size();
rt_ret =
rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
return false;
}

// Memcpy node def
rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),


+ 1
- 0
src/ge/ge_runtime/task/aicpu_task.h View File

@@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
std::shared_ptr<AicpuTaskInfo> task_info_;
void *stream_;
void *args_;
void *ext_info_;
void *input_output_addr_;
};
} // namespace model_runner


Loading…
Cancel
Save