From 03405224ada708caa2cc34f4083b3da37fdc39af Mon Sep 17 00:00:00 2001 From: gukecai Date: Fri, 7 Aug 2020 11:12:30 +0800 Subject: [PATCH] Support Aicpu Dynamic Shape --- inc/framework/ge_runtime/task_info.h | 5 +++- src/ge/ge_runtime/task/aicpu_task.cc | 45 +++++++++++++++++++++++++--- src/ge/ge_runtime/task/aicpu_task.h | 1 + 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index 68d71870..e36c4333 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo { class AicpuTaskInfo : public TaskInfo { public: AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, - const std::string &node_def, const std::vector &input_data_addrs, + const std::string &node_def, const std::string &ext_info, const std::vector &input_data_addrs, const std::vector &output_data_addrs, bool dump_flag) : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), so_name_(so_name), kernel_name_(kernel_name), node_def_(node_def), + ext_info_(ext_info), input_data_addrs_(input_data_addrs), output_data_addrs_(output_data_addrs) {} ~AicpuTaskInfo() override {} @@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo { const std::string &node_def() const { return node_def_; } const std::vector &input_data_addrs() const { return input_data_addrs_; } const std::vector &output_data_addrs() const { return output_data_addrs_; } + const std::string &ext_info() const { return ext_info_; } private: std::string so_name_; std::string kernel_name_; std::string node_def_; + std::string ext_info_; std::vector input_data_addrs_; std::vector output_data_addrs_; }; diff --git a/src/ge/ge_runtime/task/aicpu_task.cc b/src/ge/ge_runtime/task/aicpu_task.cc index 9b126ec0..332a695c 100644 --- a/src/ge/ge_runtime/task/aicpu_task.cc +++ b/src/ge/ge_runtime/task/aicpu_task.cc @@ -47,10 +47,36 @@ bool AicpuTask::Distribute() { auto io_addrs_num = static_cast(io_addrs.size()); auto io_addrs_size = static_cast(io_addrs_num * sizeof(void *)); constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); - uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size; - uint32_t args_size = - sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast(task_info_->node_def().size()); - aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num}; + uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; + uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); + uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + + static_cast(task_info_->node_def().size()) + sizeof(uint32_t); + + aicpu::AicpuParamHead aicpu_param_head; + aicpu_param_head.length = args_size; + aicpu_param_head.ioAddrNum = io_addrs_num; + auto ext_info = task_info_->ext_info(); + uint32_t ext_size = ext_info.size(); + if (ext_info.empty()) { + aicpu_param_head.extInfoLength = 0; + aicpu_param_head.extInfoAddr = 0; + } else { + rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); + if (flag != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag); + return false; + } + + flag = rtMemcpy(ext_info_, ext_size, reinterpret_cast(ext_info.data()), ext_size, RT_MEMCPY_HOST_TO_DEVICE); + if (flag != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); + return false; + } + + GELOGI("ext info size:", ext_size); + aicpu_param_head.extInfoLength = ext_size; + aicpu_param_head.extInfoAddr = reinterpret_cast(ext_info_); + } // Malloc device memory for args rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); @@ -76,6 +102,17 @@ bool AicpuTask::Distribute() { return false; } } + + // Memcpy node def + auto size = task_info_->node_def().size(); + rt_ret = + rtMemcpy(reinterpret_cast(reinterpret_cast(args_) + node_def_len_offset), sizeof(uint32_t), + reinterpret_cast(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); + return false; + } + // Memcpy node def rt_ret = rtMemcpy(reinterpret_cast(reinterpret_cast(args_) + node_def_addr_offset), task_info_->node_def().size(), reinterpret_cast(task_info_->node_def().data()), diff --git a/src/ge/ge_runtime/task/aicpu_task.h b/src/ge/ge_runtime/task/aicpu_task.h index cc21af8a..2d3c5040 100644 --- a/src/ge/ge_runtime/task/aicpu_task.h +++ b/src/ge/ge_runtime/task/aicpu_task.h @@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater { std::shared_ptr task_info_; void *stream_; void *args_; + void *ext_info_; void *input_output_addr_; }; } // namespace model_runner