diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index b6a78f9e..92d1e325 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -345,6 +345,53 @@ Status TbeOpTask::AllocateWorkspaces(const vector &workspace_sizes) { return SUCCESS; } +Status TbeOpTask::UpdateIoAddr(std::vector &args, const std::vector &inputs, + const std::vector &outputs) { + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); + + const vector v_is_input_const = op_desc_->GetIsInputConst(); + size_t non_const_index = 0; + for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { + const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast(i)); + if (tensor_desc == nullptr) { + GELOGD("SingleOp: %s, Index: %zu, has no input", op_desc_->GetName().c_str(), i); + continue; + } + if (i < v_is_input_const.size() && v_is_input_const[i]) { + if (i >= arg_num) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Args size is %zu, but get index is %zu.", arg_num, i); + REPORT_INNER_ERROR("E19999", "[Check][Size] Args size is %zu, but get index is %zu.", arg_num, i); + return ACL_ERROR_GE_PARAM_INVALID; + } + auto addr = reinterpret_cast(arg_base[i]); + GELOGD("SingleOp: %s, Index: %zu, input is const, addr = %p", op_desc_->GetName().c_str(), i, addr); + args.emplace_back(addr); + continue; + } + if (non_const_index >= inputs.size()) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Input size is %zu, but get non_const_index is %zu", + inputs.size(), non_const_index); + REPORT_INNER_ERROR("E19999", "[Check][Size] Input size is %zu, but get non_const_index is %zu", + inputs.size(), non_const_index); + return ACL_ERROR_GE_PARAM_INVALID; + } + auto addr = inputs[non_const_index].data; + GELOGD("SingleOp: %s, input[%zu], addr = %p", op_desc_->GetName().c_str(), i, addr); + args.emplace_back(addr); + non_const_index++; + } + + for (size_t i = 0; i < outputs.size(); ++i) { + auto addr = outputs[i].data; + GELOGD("SingleOp: %s, output[%zu] addr = %p", op_desc_->GetName().c_str(), i, addr); + args.emplace_back(addr); + } + + return SUCCESS; +} + Status TbeOpTask::LaunchKernel(const vector &input_desc, const vector &input_buffers, vector &output_desc, @@ -355,12 +402,7 @@ Status TbeOpTask::LaunchKernel(const vector &input_desc, GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo()); GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed."); std::vector args; - for (auto &buffer : input_buffers) { - args.emplace_back(buffer.data); - } - for (auto &buffer : output_buffers) { - args.emplace_back(buffer.data); - } + GE_CHK_STATUS_RET(UpdateIoAddr(args, input_buffers, output_buffers), "[Update][IoAddr] failed."); for (auto &buffer : workspaces_) { args.emplace_back(buffer); } diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 19320bc0..0cbc1a29 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -101,6 +101,8 @@ class TbeOpTask : public OpTask { const vector &output_desc); Status AllocateWorkspaces(const std::vector &workspace_sizes); Status DoLaunchKernel(rtStream_t stream); + Status UpdateIoAddr(std::vector &args, const std::vector &inputs, + const std::vector &outputs); const void *stub_func_ = nullptr; std::unique_ptr args_; diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index a17c9012..472a88c3 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -91,8 +91,9 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { TbeOpTask task_tmp; TbeOpTask *task = &task_tmp; ASSERT_EQ(model.BuildKernelTask(task_def, &task), SUCCESS); + ge::DataBuffer data_buffer; vector input_desc; - vector input_buffers; + vector input_buffers = { data_buffer }; vector output_desc; vector output_buffers; task->node_ = node; @@ -110,8 +111,36 @@ TEST_F(UtestSingleOpTask, test_build_kernel_task) { task->args_.reset(&task_args); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); - char handle_tmp = '0'; - char *handle = &handle_tmp; + char *handle = "00"; task->SetHandle(handle); ASSERT_EQ(task->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_), SUCCESS); -} \ No newline at end of file +} + +TEST_F(UtestSingleOpTask, test_update_ioaddr) { + auto graph = make_shared("graph"); + auto op_desc = make_shared("Add", "Add"); + + GeTensorDesc desc; + op_desc->AddInputDesc(desc); + op_desc->AddInputDesc(desc); + op_desc->AddOutputDesc(desc); + vector is_input_const = { true, false }; + op_desc->SetIsInputConst(is_input_const); + auto node = graph->AddNode(op_desc); + + TbeOpTask task; + task.op_desc_ = op_desc; + task.args_.reset(new (std::nothrow) uint8_t[sizeof(void *) * 3]); + + vector args; + vector inputs; + vector outputs; + ASSERT_EQ(task.UpdateIoAddr(args, inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); + task.arg_size_ = sizeof(void *) * 3; + ASSERT_EQ(task.UpdateIoAddr(args, inputs, outputs), ACL_ERROR_GE_PARAM_INVALID); + + ge::DataBuffer data_buffer; + inputs = { data_buffer }; + ASSERT_EQ(task.UpdateIoAddr(args, inputs, outputs), SUCCESS); +} +