Browse Source

Fix ut.

tags/v1.5.1
zhaozhixuan 3 years ago
parent
commit
813f2fe4a2
3 changed files with 9 additions and 7 deletions
  1. +4
    -6
      ge/single_op/task/op_task.cc
  2. +3
    -1
      ge/single_op/task/op_task.h
  3. +2
    -0
      ge/single_op/task/tbe_task_builder.cc

+ 4
- 6
ge/single_op/task/op_task.cc View File

@@ -346,7 +346,7 @@ Status TbeOpTask::AllocateWorkspaces(const vector<int64_t> &workspace_sizes) {
}

Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) {
size_t args_size = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize() + workspaces_.size();
size_t args_size = input_num_ + output_num_ + workspaces_.size();
if (tiling_buffer_ != nullptr) {
args_size++;
}
@@ -361,12 +361,12 @@ Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) {
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED;
}

args_.reset(args.release());
args_ = std::move(args);
arg_size_ = temp_size;
}

uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get());
size_t arg_index = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize();
size_t arg_index = input_num_ + output_num_;
for (size_t i = 0; i < workspaces_.size(); ++i) {
arg_base[arg_index++] = reinterpret_cast<uintptr_t>(workspaces_[i]);
}
@@ -382,7 +382,6 @@ Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) {
}

Status TbeOpTask::SetArgIndex() {
arg_index_.clear();
const vector<bool> v_is_input_const = op_desc_->GetIsInputConst();
size_t input_index = 0;
for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) {
@@ -416,9 +415,8 @@ Status TbeOpTask::UpdateIoAddr(const vector<DataBuffer> &inputs, const vector<Da
arg_base[arg_index_[i]] = reinterpret_cast<uintptr_t>(inputs[i].data);
}

size_t input_size = op_desc_->GetInputsSize();
for (size_t i = 0; i < op_desc_->GetOutputsSize(); ++i) {
arg_base[input_size + i] = reinterpret_cast<uintptr_t>(outputs[i].data);
arg_base[input_num_ + i] = reinterpret_cast<uintptr_t>(outputs[i].data);
}

return SUCCESS;


+ 3
- 1
ge/single_op/task/op_task.h View File

@@ -123,7 +123,9 @@ class TbeOpTask : public OpTask {
void* handle_ = nullptr;
std::string original_kernel_key_;
std::string node_info_;
std::vector<size_t> arg_index_;
std::vector<size_t> arg_index_; // data index in args
size_t input_num_; // Include const input
size_t output_num_;
};

class AiCpuBaseTask : public OpTask {


+ 2
- 0
ge/single_op/task/tbe_task_builder.cc View File

@@ -388,6 +388,8 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam &para
task.SetStubFunc(stub_name_, stub_func);
}
GE_CHK_STATUS_RET(task.SetArgIndex(), "[Set][ArgTable] failed.");
task.input_num_ = op_desc_->GetInputsSize();
task.output_num_ = op_desc_->GetOutputsSize();

return SUCCESS;
}


Loading…
Cancel
Save