From 613e221a971dc46f3fcf0dc443e8dd1fdac6b5bf Mon Sep 17 00:00:00 2001 From: chuxing Date: Sat, 12 Dec 2020 11:18:27 +0800 Subject: [PATCH] fix dynamic single op --- .../task/aicpu_kernel_task_builder.cc | 4 + ge/single_op/task/op_task.cc | 92 +++++++------------ ge/single_op/task/op_task.h | 3 +- 3 files changed, 38 insertions(+), 61 deletions(-) diff --git a/ge/single_op/task/aicpu_kernel_task_builder.cc b/ge/single_op/task/aicpu_kernel_task_builder.cc index c676ccf8..f8a2bd1b 100755 --- a/ge/single_op/task/aicpu_kernel_task_builder.cc +++ b/ge/single_op/task/aicpu_kernel_task_builder.cc @@ -97,6 +97,10 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id, cons return ret; } + if (task.GetUnknownType() == DEPEND_COMPUTE) { + GELOGE(FAILED, "AiCpuCCTask unknown type is depend compute, it's not supported now."); + return FAILED; + } auto aicpu_param_head = reinterpret_cast(task.args_.get()); if (task.ext_info_addr_dev_ != nullptr) { aicpu_param_head->extInfoLength = kernel_ext_info.size(); diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index a714c6a8..22433ec9 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -454,6 +454,29 @@ Status AiCpuBaseTask::UpdateShapeToOutputDesc(const GeShape &shape_new, GeTensor return SUCCESS; } +Status AiCpuBaseTask::UpdateIoAddr(const vector &inputs, const vector &outputs) { + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); + + // input number and output number was check in ValidateParams + for (size_t i = 0; i < inputs.size(); ++i) { + auto addr = inputs[i].data; + GE_CHECK_NOTNULL(addr); + GELOGD("AICpuTask input[%zu] addr = %p", i, addr); + *arg_base++ = reinterpret_cast(addr); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + auto addr = outputs[i].data; + GE_CHECK_NOTNULL(addr); + GELOGD("AICpuTask output[%zu] addr = %p", i, addr); + *arg_base++ = reinterpret_cast(addr); + } + + return SUCCESS; +} + AiCpuTask::~AiCpuTask() { FreeHbm(args_); FreeHbm(io_addr_); @@ -631,40 +654,6 @@ Status AiCpuTask::UpdateShapeAndDataByResultSummary(vector &output return SUCCESS; } -Status AiCpuTask::SetIO(const vector &inputs, vector &outputs) { - vector io_addrs; - io_addrs.reserve(num_inputs_ + num_outputs_); - for (size_t i = 0; i < num_inputs_; ++i) { - GE_CHECK_NOTNULL(inputs[i]); - GELOGD("AiCpuTask input[%zu] addr = %p", i, inputs[i]); - io_addrs.emplace_back(reinterpret_cast(inputs[i])); - } - - if (unknown_type_ != DEPEND_COMPUTE) { - for (size_t i = 0; i < num_outputs_; ++i) { - GE_CHECK_NOTNULL(outputs[i]); - GELOGD("AiCpuTask output[%zu] addr = %p", i, outputs[i]); - io_addrs.emplace_back(reinterpret_cast(outputs[i])); - } - } else { - for (size_t i = 0; i < num_outputs_; ++i) { - void *summary_addr = output_summary_[i]; - io_addrs.emplace_back(reinterpret_cast(summary_addr)); - } - } - - if (!io_addrs.empty()) { - auto *dst_io_addr = const_cast(reinterpret_cast(io_addr_)); - GE_CHK_RT_RET(rtMemcpy(dst_io_addr, - sizeof(uint64_t) * io_addrs.size(), - &io_addrs[0], - sizeof(uint64_t) * io_addrs.size(), - RT_MEMCPY_HOST_TO_DEVICE)); - GE_CHECK_NOTNULL(dst_io_addr); - }; - return SUCCESS; -} - Status AiCpuTask::InitForSummaryAndCopy() { if (unknown_type_ != DEPEND_COMPUTE || num_outputs_ == 0) { GELOGI("Unknown_type is %d, output num is %d.", unknown_type_, num_outputs_); @@ -736,17 +725,17 @@ Status AiCpuTask::LaunchKernel(const std::vector &input_desc, std::vector &output_buffers, rtStream_t stream) { GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); - std::vector inputs; - std::vector outputs; - for (auto &buffer : input_buffers) { - inputs.emplace_back(buffer.data); - } - for (auto &buffer : output_buffers) { - outputs.emplace_back(buffer.data); + if (unknown_type_ == DEPEND_COMPUTE) { + std::vector summary_buffers; + for (size_t i = 0; i < num_outputs_; ++i) { + summary_buffers.emplace_back(output_summary_[i], sizeof(aicpu::FWKAdapter::ResultSummary), false); + } + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, summary_buffers)); + } else { + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, output_buffers)); } - GE_CHK_STATUS_RET_NOLOG(SetIO(inputs, outputs)); - GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); + GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); if (unknown_type_ == DEPEND_SHAPE_RANGE) { GE_CHK_RT_RET(rtStreamSynchronize(stream)); GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); @@ -817,24 +806,9 @@ Status AiCpuCCTask::LaunchKernel(const std::vector &input_desc, std::vector &output_desc, std::vector &output_buffers, rtStream_t stream) { - GE_CHK_BOOL_RET_STATUS(unknown_type_ != DEPEND_COMPUTE, FAILED, - "AiCpuCCTask unknown type[%d] is depend compute, it's not supported now.", - unknown_type_); - GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); - - size_t arg_index = 0; - auto *task_io_addr = reinterpret_cast(io_addr_); - GE_CHECK_NOTNULL(task_io_addr); - for (auto &input : input_buffers) { - task_io_addr[arg_index++] = reinterpret_cast(input.data); - } - for (auto &output : output_buffers) { - task_io_addr[arg_index++] = reinterpret_cast(output.data); - } - + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, output_buffers)); GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); - if (unknown_type_ == DEPEND_SHAPE_RANGE) { GE_CHK_RT_RET(rtStreamSynchronize(stream)); GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 04e0def2..e2122b6f 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -112,6 +112,7 @@ class AiCpuBaseTask : public OpTask { UnknowShapeOpType GetUnknownType() const { return unknown_type_; } protected: + Status UpdateIoAddr(const std::vector &inputs, const std::vector &outputs); Status SetExtInfoAndType(const std::string &kernel_ext_info, uint64_t kernel_id); Status UpdateExtInfo(const std::vector &input_desc, @@ -145,8 +146,6 @@ class AiCpuTask : public AiCpuBaseTask { Status SetMemCopyTask(const domi::KernelExDef &kernel_def); private: - Status SetIO(const vector &inputs, vector &outputs); - // for copy task. Status InitForSummaryAndCopy(); Status UpdateShapeAndDataByResultSummary(vector &output_desc,