@@ -20,6 +20,7 @@
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/manager/util/hcom_util.h"
#include "graph/runtime_inference_context.h"
#include "graph/utils/type_utils.h"
#include "graph/types.h"
#include "hccl/hcom.h"
@@ -176,8 +177,61 @@ Status RdmaNodeTask::Init(TaskContext &context) {
  return SUCCESS;
}

Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
                                 vector<HcomRemoteAccessAddrInfo> &addr_infos) {
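// ExtractTensor resolves the remote address table for this RDMA node, validates its shape, allocates
// local RDMA_HBM output memory where needed, and sizes addr_infos with one entry per table row.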
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
  RuntimeInferenceContext *ctx = nullptr;
  GE_CHK_STATUS_RET(
      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));

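  // The remote address table is not a regular task input; it is fetched from the RuntimeInferenceContext
  // via remote_index_ (presumably the producer node id and output index passed to GetTensor below).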
  ge::Tensor remote_tensor;
  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
  if (data == nullptr) {
    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
      return SUCCESS;
    } else {
      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr and kRdmaScatterTypes does not contain %s",
                         context.GetNodeItem().NodeType().c_str());
      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr and kRdmaScatterTypes does not contain %s",
             context.GetNodeItem().NodeType().c_str());
      return FAILED;
    }
  }
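  // The fetched tensor is treated as a variable table: kVarTableDims dimensions with kVarTableRowCnt
  // fields per row, where the field at kVarTableIdxLen is later read as the transfer length.
  // Note that the && below means the shape is only rejected when both conditions fail.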
  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
  if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu "
                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
                       dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed, "
           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
           dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    return PARAM_INVALID;
  }

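  // HCOMREMOTEREAD: sum the per-row lengths to get the total remote size, then allocate an RDMA_HBM
  // buffer of that size for every output of the node.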
  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
    size_t remote_size = 0;
    for (auto idx = 0; idx < dims.front(); ++idx) {
      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
      auto line_idx = idx * kVarTableRowCnt;
      remote_size += data[line_idx + kVarTableIdxLen];
    }
    auto allocator = NpuMemoryAllocator::GetAllocator();
    GE_CHECK_NOTNULL(allocator);
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    for (auto i = 0; i < context.NumOutputs(); ++i) {
      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
    }
  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
  }

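  // Choose the local buffer to pair with the remote rows: read-type ops (kRdmaReadTypes) write into an
  // output, the remaining (write/scatter) ops read from an input at local_index_.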
  TensorValue *tv;
  if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
    tv = context.MutableOutput(local_index_);
@@ -185,6 +239,7 @@ Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *
    tv = context.MutableInput(local_index_);
  }
  GE_CHECK_NOTNULL(tv);
  auto row_num = dims.front();
  addr_infos.resize(row_num);
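  // skip_flag_ path: the op carries a dedicated "local_offset" input, located here by name and
  // presumably used to derive the local address of each row.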
  if (skip_flag_) {
    int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
@@ -239,65 +294,6 @@ Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *
  return SUCCESS;
}

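// This ExtractTensor variant delegates the addr_infos construction to SetAddrInfo after sizing and
// allocating the outputs (see the SetAddrInfo call at its end).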
Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
  RuntimeInferenceContext *ctx = nullptr;
  GE_CHK_STATUS_RET(
      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));

  ge::Tensor remote_tensor;
  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
  if (data == nullptr) {
    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
      return SUCCESS;
    } else {
      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
                         context.GetNodeItem().NodeType().c_str());
      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
             context.GetNodeItem().NodeType().c_str());
      return FAILED;
    }
  }
  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
  if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
                       dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
           dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
    return PARAM_INVALID;
  }

  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
    size_t remote_size = 0;
    for (auto idx = 0; idx < dims.front(); ++idx) {
      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
      auto line_idx = idx * kVarTableRowCnt;
      remote_size += data[line_idx + kVarTableIdxLen];
    }
    auto allocator = NpuMemoryAllocator::GetAllocator();
    GE_CHECK_NOTNULL(allocator);
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    for (auto i = 0; i < context.NumOutputs(); ++i) {
      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
    }
  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
    AllocationAttr attr;
    attr.SetMemType(RDMA_HBM);
    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
  }

  auto row_num = dims.front();
  return SetAddrInfo(context, ctx, data, row_num, addr_infos);
}

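// ExecuteAsync issues the prepared remote access asynchronously; the work appears to be wrapped in the
// HcomExecEnqueueRemoteAccess lambda set up below.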
Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
  GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName());
  auto HcomExecEnqueueRemoteAccess =