Browse Source

fixing update arg table

tags/v1.2.0
chuxing 3 years ago
parent
commit
1ccdf2d27c
2 changed files with 13 additions and 8 deletions
  1. +10
    -6
      ge/single_op/task/op_task.cc
  2. +3
    -2
      ge/single_op/task/op_task.h

+ 10
- 6
ge/single_op/task/op_task.cc View File

@@ -112,8 +112,9 @@ Status OpTask::GetProfilingArgs(std::string &model_name, std::string &op_name, u
Status OpTask::UpdateRunInfo(const vector<GeTensorDesc> &input_desc, const vector<GeTensorDesc> &output_desc) {
return UNSUPPORTED;
}
Status OpTask::UpdateArgTable(const SingleOpModelParam &param) {
auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param);

Status OpTask::DoUpdateArgTable(const SingleOpModelParam &param, bool keep_workspace) {
auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, keep_workspace);
auto all_addresses = BuildTaskUtils::JoinAddresses(addresses);
uintptr_t *arg_base = nullptr;
size_t arg_num = 0;
@@ -132,6 +133,10 @@ Status OpTask::UpdateArgTable(const SingleOpModelParam &param) {
return SUCCESS;
}

Status OpTask::UpdateArgTable(const SingleOpModelParam &param) {
return DoUpdateArgTable(param, true);
}

Status OpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc,
const vector<DataBuffer> &input_buffers,
vector<GeTensorDesc> &output_desc,
@@ -792,10 +797,9 @@ Status AiCpuTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
return SUCCESS;
}

Status AiCpuTask::UpdateArgTable(const SingleOpModelParam &param) {
auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, false);
io_addr_host_ = BuildTaskUtils::JoinAddresses(addresses);
return SUCCESS;
Status AiCpuBaseTask::UpdateArgTable(const SingleOpModelParam &param) {
// aicpu do not have workspace, for now
return DoUpdateArgTable(param, false);
}

void AiCpuTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) {


+ 3
- 2
ge/single_op/task/op_task.h View File

@@ -54,6 +54,8 @@ class OpTask {
rtStream_t stream);

protected:
Status DoUpdateArgTable(const SingleOpModelParam &param, bool keep_workspace);

DumpProperties dump_properties_;
DumpOp dump_op_;
OpDescPtr op_desc_;
@@ -110,7 +112,7 @@ class AiCpuBaseTask : public OpTask {
AiCpuBaseTask() = default;
~AiCpuBaseTask() override;
UnknowShapeOpType GetUnknownType() const { return unknown_type_; }
Status UpdateArgTable(const SingleOpModelParam &param) override;
protected:
Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
Status SetInputConst();
@@ -137,7 +139,6 @@ class AiCpuTask : public AiCpuBaseTask {
~AiCpuTask() override;

Status LaunchKernel(rtStream_t stream) override;
Status UpdateArgTable(const SingleOpModelParam &param) override;
void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override;

Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,


Loading…
Cancel
Save