diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index 92d1e325..632cd4d8 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -345,14 +345,46 @@ Status TbeOpTask::AllocateWorkspaces(const vector &workspace_sizes) { return SUCCESS; } -Status TbeOpTask::UpdateIoAddr(std::vector &args, const std::vector &inputs, - const std::vector &outputs) { - uintptr_t *arg_base = nullptr; - size_t arg_num = 0; - GetIoAddr(arg_base, arg_num); +Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) { + size_t args_size = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize() + workspaces_.size(); + if (tiling_buffer_ != nullptr) { + args_size++; + } + size_t temp_size = args_size * sizeof(void *); + if (arg_size_ < temp_size) { + GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size); + std::unique_ptr args(new (std::nothrow) uint8_t[temp_size]()); + GE_CHECK_NOTNULL(args); + if (memcpy_s(args.get(), temp_size, args_.get(), arg_size_) != EOK) { + GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str()); + REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str()); + return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; + } + args_.reset(args.release()); + arg_size_ = temp_size; + } + + uintptr_t *arg_base = reinterpret_cast(args_.get()); + size_t arg_index = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize(); + for (size_t i = 0; i < workspaces_.size(); ++i) { + arg_base[arg_index++] = reinterpret_cast(workspaces_[i]); + } + + if (tiling_buffer_ != nullptr) { + GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); + GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), + RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); + arg_base[arg_index] = reinterpret_cast(tiling_buffer_); + } + + return SUCCESS; +} + +Status TbeOpTask::SetArgIndex() { + arg_index_.clear(); const vector v_is_input_const = op_desc_->GetIsInputConst(); - size_t non_const_index = 0; + size_t input_index = 0; for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast(i)); if (tensor_desc == nullptr) { @@ -360,33 +392,33 @@ Status TbeOpTask::UpdateIoAddr(std::vector &args, const std::vector= arg_num) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Args size is %zu, but get index is %zu.", arg_num, i); - REPORT_INNER_ERROR("E19999", "[Check][Size] Args size is %zu, but get index is %zu.", arg_num, i); - return ACL_ERROR_GE_PARAM_INVALID; - } - auto addr = reinterpret_cast(arg_base[i]); - GELOGD("SingleOp: %s, Index: %zu, input is const, addr = %p", op_desc_->GetName().c_str(), i, addr); - args.emplace_back(addr); + GELOGD("SingleOp: %s, Index: %zu, input is const", op_desc_->GetName().c_str(), i); + input_index++; continue; } - if (non_const_index >= inputs.size()) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Input size is %zu, but get non_const_index is %zu", - inputs.size(), non_const_index); - REPORT_INNER_ERROR("E19999", "[Check][Size] Input size is %zu, but get non_const_index is %zu", - inputs.size(), non_const_index); - return ACL_ERROR_GE_PARAM_INVALID; - } - auto addr = inputs[non_const_index].data; - GELOGD("SingleOp: %s, input[%zu], addr = %p", op_desc_->GetName().c_str(), i, addr); - args.emplace_back(addr); - non_const_index++; + arg_index_.emplace_back(input_index); + input_index++; } + return SUCCESS; +} - for (size_t i = 0; i < outputs.size(); ++i) { - auto addr = outputs[i].data; - GELOGD("SingleOp: %s, output[%zu] addr = %p", op_desc_->GetName().c_str(), i, addr); - args.emplace_back(addr); +Status TbeOpTask::UpdateIoAddr(const vector &inputs, const vector &outputs) { + if (arg_index_.size() != inputs.size()) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Args size is %zu, but get input size is %zu.", + arg_index_.size(), inputs.size()); + REPORT_INNER_ERROR("E19999", "[Check][Size] Args size is %zu, but get input size is %zu.", + arg_index_.size(), inputs.size()); + return ACL_ERROR_GE_PARAM_INVALID; + } + + uintptr_t *arg_base = reinterpret_cast(args_.get()); + for (size_t i = 0; i < arg_index_.size(); ++i) { + arg_base[arg_index_[i]] = reinterpret_cast(inputs[i].data); + } + + size_t input_size = op_desc_->GetInputsSize(); + for (size_t i = 0; i < op_desc_->GetOutputsSize(); ++i) { + arg_base[input_size + i] = reinterpret_cast(outputs[i].data); } return SUCCESS; @@ -398,37 +430,11 @@ Status TbeOpTask::LaunchKernel(const vector &input_desc, vector &output_buffers, rtStream_t stream) { GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); + GE_CHK_STATUS_RET(UpdateIoAddr(input_buffers, output_buffers), "[Update][IoAddr] failed."); GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc)); GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo()); GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed."); - std::vector args; - GE_CHK_STATUS_RET(UpdateIoAddr(args, input_buffers, output_buffers), "[Update][IoAddr] failed."); - for (auto &buffer : workspaces_) { - args.emplace_back(buffer); - } - - if (tiling_buffer_ != nullptr) { - GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); - GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), - RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); - - args.emplace_back(tiling_buffer_); - } - - GELOGD("Dst size is %zu, src size is %zu.", arg_size_, args.size() * sizeof(void *)); - // node with workspace: build can not get size of workspace, need to update arg_size_ when execute - if (arg_size_ < (args.size() * sizeof(void *))) { - size_t temp_size = args.size() * sizeof(void *); - GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size); - args_.reset(new(std::nothrow) uint8_t[temp_size]()); - GE_CHECK_NOTNULL(args_); - arg_size_ = temp_size; - } - if (memcpy_s(args_.get(), arg_size_, args.data(), args.size() * sizeof(void *)) != EOK) { - GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str()); - REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str()); - return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; - } + GE_CHK_STATUS_RET(UpdateTilingArgs(stream), "[Update][TilingArgs] failed."); GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str()); GE_CHK_STATUS_RET(DoLaunchKernel(stream), "Failed to do launch kernel."); diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 0cbc1a29..d3e8383d 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -85,6 +85,7 @@ class TbeOpTask : public OpTask { const OpDescPtr &op_desc, const domi::KernelDefWithHandle& kernel_def_with_handle); Status UpdateRunInfo() override; + Status SetArgIndex(); const void *GetArgs() const; size_t GetArgSize() const; @@ -100,9 +101,9 @@ class TbeOpTask : public OpTask { Status UpdateNodeByShape(const vector &input_desc, const vector &output_desc); Status AllocateWorkspaces(const std::vector &workspace_sizes); + Status UpdateTilingArgs(rtStream_t stream); Status DoLaunchKernel(rtStream_t stream); - Status UpdateIoAddr(std::vector &args, const std::vector &inputs, - const std::vector &outputs); + Status UpdateIoAddr(const vector &inputs, const vector &outputs); const void *stub_func_ = nullptr; std::unique_ptr args_; @@ -122,6 +123,7 @@ class TbeOpTask : public OpTask { void* handle_ = nullptr; std::string original_kernel_key_; std::string node_info_; + std::vector arg_index_; }; class AiCpuBaseTask : public OpTask { diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index c7ff13d1..e5206ea6 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -387,6 +387,7 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ } task.SetStubFunc(stub_name_, stub_func); } + GE_CHK_STATUS_RET(task.SetArgIndex(), "[Set][ArgTable] failed."); return SUCCESS; } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index 472a88c3..020efc23 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -95,7 +95,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { vector input_desc; vector input_buffers = { data_buffer }; vector output_desc; - vector output_buffers; + vector output_buffers = { data_buffer }; task->node_ = node; OpTilingFunc op_tiling_func = [](const TeOpParas &, const OpCompileInfo &, OpRunInfo &) -> bool {return true;}; OpTilingRegistryInterf("Add", op_tiling_func); @@ -107,8 +107,7 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { task->max_tiling_size_ = 64; task->tiling_data_ = "tiling_data"; task->arg_size_ = 64; - uint8_t task_args{0}; - task->args_.reset(&task_args); + task->args_.reset(new (std::nothrow) uint8_t[sizeof(void *) * 3]); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); char *handle = "00"; @@ -130,17 +129,25 @@ TEST_F(UtestSingleOpTask, test_update_ioaddr) { TbeOpTask task; task.op_desc_ = op_desc; - task.args_.reset(new (std::nothrow) uint8_t[sizeof(void *) * 3]); + task.node_ = node; + ASSERT_EQ(task.SetArgIndex(), SUCCESS); + task.arg_size_ = sizeof(void *) * 4; + task.args_.reset(new (std::nothrow) uint8_t[task.arg_size_]); + task.arg_index_ = {0}; vector args; vector inputs; vector outputs; - ASSERT_EQ(task.UpdateIoAddr(args, inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); - task.arg_size_ = sizeof(void *) * 3; - ASSERT_EQ(task.UpdateIoAddr(args, inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); ge::DataBuffer data_buffer; inputs = { data_buffer }; - ASSERT_EQ(task.UpdateIoAddr(args, inputs, outputs), SUCCESS); + outputs = { data_buffer }; + ASSERT_EQ(task.UpdateIoAddr(inputs, outputs), SUCCESS); + + task.tiling_buffer_ = (void *)0x0001; + task.workspaces_ = { (void *)0x0002 }; + ASSERT_EQ(task.UpdateTilingArgs(nullptr), SUCCESS); + task.tiling_buffer_ = nullptr; } diff --git a/tests/ut/ge/single_op/single_op_unittest.cc b/tests/ut/ge/single_op/single_op_unittest.cc index db3de7ec..09aac153 100644 --- a/tests/ut/ge/single_op/single_op_unittest.cc +++ b/tests/ut/ge/single_op/single_op_unittest.cc @@ -103,7 +103,7 @@ TEST_F(UtestSingleOp, test_dynamic_singleop_execute_async1) { EXPECT_EQ(desc_ptr->AddInputDesc("x", GeTensorDesc(GeShape({2}), FORMAT_NCHW)), GRAPH_SUCCESS); dynamic_single_op.op_task_->op_desc_ = desc_ptr; // UpdateRunInfo failed - EXPECT_EQ(dynamic_single_op.ExecuteAsync(input_desc, input_buffers, output_desc, output_buffers), PARAM_INVALID); + EXPECT_EQ(dynamic_single_op.ExecuteAsync(input_desc, input_buffers, output_desc, output_buffers), ACL_ERROR_GE_PARAM_INVALID); }