diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index 632cd4d8..28ec7f64 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -346,7 +346,7 @@ Status TbeOpTask::AllocateWorkspaces(const vector &workspace_sizes) { } Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) { - size_t args_size = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize() + workspaces_.size(); + size_t args_size = input_num_ + output_num_ + workspaces_.size(); if (tiling_buffer_ != nullptr) { args_size++; } @@ -361,12 +361,12 @@ Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) { return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; } - args_.reset(args.release()); + args_ = std::move(args); arg_size_ = temp_size; } uintptr_t *arg_base = reinterpret_cast(args_.get()); - size_t arg_index = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize(); + size_t arg_index = input_num_ + output_num_; for (size_t i = 0; i < workspaces_.size(); ++i) { arg_base[arg_index++] = reinterpret_cast(workspaces_[i]); } @@ -382,7 +382,6 @@ Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) { } Status TbeOpTask::SetArgIndex() { - arg_index_.clear(); const vector v_is_input_const = op_desc_->GetIsInputConst(); size_t input_index = 0; for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { @@ -416,9 +415,8 @@ Status TbeOpTask::UpdateIoAddr(const vector &inputs, const vector(inputs[i].data); } - size_t input_size = op_desc_->GetInputsSize(); for (size_t i = 0; i < op_desc_->GetOutputsSize(); ++i) { - arg_base[input_size + i] = reinterpret_cast(outputs[i].data); + arg_base[input_num_ + i] = reinterpret_cast(outputs[i].data); } return SUCCESS; diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index d3e8383d..f93e031a 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -123,7 +123,9 @@ class TbeOpTask : public OpTask { void* handle_ = nullptr; std::string original_kernel_key_; std::string node_info_; - std::vector arg_index_; + std::vector arg_index_; // data index in args + size_t input_num_; // Include const input + size_t output_num_; }; class AiCpuBaseTask : public OpTask { diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index e5206ea6..db8ecfe2 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -388,6 +388,8 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ task.SetStubFunc(stub_name_, stub_func); } GE_CHK_STATUS_RET(task.SetArgIndex(), "[Set][ArgTable] failed."); + task.input_num_ = op_desc_->GetInputsSize(); + task.output_num_ = op_desc_->GetOutputsSize(); return SUCCESS; }