From 1ccdf2d27c43c545ab48d2e2440068e5fc68f12e Mon Sep 17 00:00:00 2001 From: chuxing Date: Fri, 18 Dec 2020 17:51:04 +0800 Subject: [PATCH] fixing update arg table --- ge/single_op/task/op_task.cc | 16 ++++++++++------ ge/single_op/task/op_task.h | 5 +++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index 4f64251c..20b1354d 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -112,8 +112,9 @@ Status OpTask::GetProfilingArgs(std::string &model_name, std::string &op_name, u Status OpTask::UpdateRunInfo(const vector &input_desc, const vector &output_desc) { return UNSUPPORTED; } -Status OpTask::UpdateArgTable(const SingleOpModelParam ¶m) { - auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param); + +Status OpTask::DoUpdateArgTable(const SingleOpModelParam ¶m, bool keep_workspace) { + auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, keep_workspace); auto all_addresses = BuildTaskUtils::JoinAddresses(addresses); uintptr_t *arg_base = nullptr; size_t arg_num = 0; @@ -132,6 +133,10 @@ Status OpTask::UpdateArgTable(const SingleOpModelParam ¶m) { return SUCCESS; } +Status OpTask::UpdateArgTable(const SingleOpModelParam ¶m) { + return DoUpdateArgTable(param, true); +} + Status OpTask::LaunchKernel(const vector &input_desc, const vector &input_buffers, vector &output_desc, @@ -792,10 +797,9 @@ Status AiCpuTask::LaunchKernel(const std::vector &input_desc, return SUCCESS; } -Status AiCpuTask::UpdateArgTable(const SingleOpModelParam ¶m) { - auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, false); - io_addr_host_ = BuildTaskUtils::JoinAddresses(addresses); - return SUCCESS; +Status AiCpuBaseTask::UpdateArgTable(const SingleOpModelParam ¶m) { + // aicpu do not have workspace, for now + return DoUpdateArgTable(param, false); } void AiCpuTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 761697cb..bf78557c 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -54,6 +54,8 @@ class OpTask { rtStream_t stream); protected: + Status DoUpdateArgTable(const SingleOpModelParam ¶m, bool keep_workspace); + DumpProperties dump_properties_; DumpOp dump_op_; OpDescPtr op_desc_; @@ -110,7 +112,7 @@ class AiCpuBaseTask : public OpTask { AiCpuBaseTask() = default; ~AiCpuBaseTask() override; UnknowShapeOpType GetUnknownType() const { return unknown_type_; } - + Status UpdateArgTable(const SingleOpModelParam ¶m) override; protected: Status UpdateIoAddr(const std::vector &inputs, const std::vector &outputs); Status SetInputConst(); @@ -137,7 +139,6 @@ class AiCpuTask : public AiCpuBaseTask { ~AiCpuTask() override; Status LaunchKernel(rtStream_t stream) override; - Status UpdateArgTable(const SingleOpModelParam ¶m) override; void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; Status LaunchKernel(const std::vector &input_desc,