Browse Source

Add SetIoAddrs for UpdateArgs.

tags/v1.2.0
zhangxiaokun 3 years ago
parent
commit
9573011b2b
4 changed files with 8 additions and 10 deletions
  1. +1
    -1
      ge/graph/load/new_model_manager/davinci_model.h
  2. +1
    -1
      ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
  3. +1
    -2
      ge/graph/load/new_model_manager/task_info/kernel_task_info.cc
  4. +5
    -6
      ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc

+ 1
- 1
ge/graph/load/new_model_manager/davinci_model.h View File

@@ -503,7 +503,7 @@ class DavinciModel {
void *cur_args = static_cast<char *>(args_) + offset;
return cur_args;
}
- void SetTotalIOAddrs(vector<void *> &io_addrs) {
+ void SetTotalIOAddrs(const vector<void *> &io_addrs) {
total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end());
}
void SetHybridArgsSize(uint32_t args_size) { total_hybrid_args_size_ += args_size; }


+ 1
- 1
ge/graph/load/new_model_manager/task_info/hccl_task_info.cc View File

@@ -230,7 +230,7 @@ Status HcclTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *
return SUCCESS;
}

- void HcclTaskInfo::SetIoAddrs() {
+ void HcclTaskInfo::SetIoAddrs(const OpDescPtr &op_desc) {
const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam();
const auto input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc);
const auto output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc);


+ 1
- 2
ge/graph/load/new_model_manager/task_info/kernel_task_info.cc View File

@@ -218,7 +218,7 @@ uint32_t KernelTaskInfo::GetDumpFlag() {
}

Status KernelTaskInfo::SuperKernelLaunch() {
- SuperKernelTaskInfo &skt_info = davinci_model_->GetSuperKernelTaskInfo();
+ const SuperKernelTaskInfo &skt_info = davinci_model_->GetSuperKernelTaskInfo();
if (skt_info.kernel_list.empty()) {
GELOGI("SuperKernelLaunch: Skt_kernel_list has no task, just return");
return SUCCESS;
@@ -448,7 +448,6 @@ void KernelTaskInfo::SetIoAddrs(const OpDescPtr &op_desc) {

Status KernelTaskInfo::UpdateArgs() {
GELOGI("KernelTaskInfo::UpdateArgs in.");

if (kernel_type_ == ccKernelType::TE) {
davinci_model_->SetTotalIOAddrs(io_addrs_);
} else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) {


+ 5
- 6
ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc View File

@@ -45,7 +45,7 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da
dst_ = reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(src_) + sizeof(void *));
// for zero copy
kind_ = RT_MEMCPY_ADDR_DEVICE_TO_DEVICE;
- GE_CHK_STATUS_RET(SetIoAddrs(op_desc, memcpy_async), "Set addr failed");
+ GE_CHK_STATUS_RET(SetIoAddrs(op_desc, memcpy_async), "Set addrs failed");
GELOGI("MemcpyAsyncTaskInfo op name %s, src_ %p, dst_ %p, args_offset %u.",
op_desc->GetName().c_str(), src_, dst_, args_offset_);
return SUCCESS;
@@ -75,8 +75,7 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da

davinci_model_->DisableZeroCopy(src_);
davinci_model_->DisableZeroCopy(dst_);
- GE_CHK_STATUS_RET(SetIoAddrs(op_desc, memcpy_async), "Set addr failed");
-
+ GE_CHK_STATUS_RET(SetIoAddrs(op_desc, memcpy_async), "Set addrs failed");
GELOGI("MemcpyAsyncTaskInfo Init Success, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu",
memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_);
return SUCCESS;
@@ -118,7 +117,7 @@ Status MemcpyAsyncTaskInfo::CalculateArgs(const domi::TaskDef &task_def, Davinci

Status MemcpyAsyncTaskInfo::SetIoAddrs(const OpDescPtr &op_desc, const domi::MemcpyAsyncDef &memcpy_async) {
uint8_t *src = nullptr;
- Status ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async_.src(), src);
+ Status ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.src(), src);
if (ret != SUCCESS) {
return ret;
}
@@ -129,11 +128,11 @@ Status MemcpyAsyncTaskInfo::SetIoAddrs(const OpDescPtr &op_desc, const domi::Mem
io_addrs_.emplace_back(fixed_addr);
} else {
uint8_t *dst = nullptr;
- ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async_.dst(), dst);
+ ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.dst(), dst);
if (ret != SUCCESS) {
return ret;
}
- io_addrs_.emplace_back(reinterpret_cast<void *>(dst_));
+ io_addrs_.emplace_back(reinterpret_cast<void *>(dst));
}

return SUCCESS;


Loading…
Cancel
Save