|
|
@@ -15,15 +15,16 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "hybrid/node_executor/hccl/hccl_node_executor.h" |
|
|
|
|
|
|
|
#include "common/ge/plugin_manager.h" |
|
|
|
#include "common/math/math_util.h" |
|
|
|
#include "external/graph/attr_value.h" |
|
|
|
#include "external/graph/types.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/manager/util/hcom_util.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "external/graph/types.h" |
|
|
|
#include "hybrid/executor/hybrid_execution_context.h" |
|
|
|
#include "hccl/hcom.h" |
|
|
|
#include "hybrid/executor/hybrid_execution_context.h" |
|
|
|
#include "runtime/event.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
@@ -267,14 +268,16 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess |
|
|
|
} |
|
|
|
auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); |
|
|
|
if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" |
|
|
|
REPORT_INNER_ERROR("E19999", |
|
|
|
"Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" |
|
|
|
"and shape dims back:%zu not equal expect:%zu, node:%s(%s)", |
|
|
|
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, |
|
|
|
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); |
|
|
|
GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," |
|
|
|
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, context.GetNodeName(), |
|
|
|
context.GetNodeItem().NodeType().c_str()); |
|
|
|
GELOGE(PARAM_INVALID, |
|
|
|
"[Check][Param]Variable table shape check failed," |
|
|
|
"number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", |
|
|
|
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, |
|
|
|
context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); |
|
|
|
dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, context.GetNodeName(), |
|
|
|
context.GetNodeItem().NodeType().c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
@@ -357,8 +360,8 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do |
|
|
|
} |
|
|
|
|
|
|
|
Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { |
|
|
|
void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, |
|
|
|
¶ms.recvcounts, ¶ms.rdispls}; |
|
|
|
void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, ¶ms.recvcounts, |
|
|
|
¶ms.rdispls}; |
|
|
|
for (size_t i = 0; i < kAllToAllVInputNums; ++i) { |
|
|
|
auto addr = context.MutableInput(i); |
|
|
|
GE_CHECK_NOTNULL(addr); |
|
|
@@ -383,13 +386,14 @@ Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { |
|
|
|
} |
|
|
|
params.sendtype = iter->second; |
|
|
|
params.recvtype = iter->second; |
|
|
|
params.group = nullptr; |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams ¶ms) { |
|
|
|
void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, |
|
|
|
¶ms.recvcounts, ¶ms.rdispls}; |
|
|
|
void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, ¶ms.recvcounts, |
|
|
|
¶ms.rdispls}; |
|
|
|
for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) { |
|
|
|
auto addr = context.MutableInput(i); |
|
|
|
GE_CHECK_NOTNULL(addr); |
|
|
@@ -418,8 +422,9 @@ Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams |
|
|
|
params.recvtype = iter->second; |
|
|
|
|
|
|
|
int64_t addr_len = 0; |
|
|
|
(void) ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); |
|
|
|
(void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); |
|
|
|
params.addrLength = static_cast<int>(addr_len); |
|
|
|
params.group = nullptr; |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -428,7 +433,7 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void() |
|
|
|
GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName()); |
|
|
|
|
|
|
|
TaskContext *p_ctx = &context; |
|
|
|
auto callback = [p_ctx, done_callback](HcclResult status){ |
|
|
|
auto callback = [p_ctx, done_callback](HcclResult status) { |
|
|
|
if (status != HCCL_SUCCESS) { |
|
|
|
GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName()); |
|
|
|
p_ctx->SetStatus(FAILED); |
|
|
@@ -460,7 +465,6 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void() |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
HcomGatherAllToAllVParams params; |
|
|
|
params.group = nullptr; |
|
|
|
GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params)); |
|
|
|
HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback); |
|
|
|
if (hccl_ret != HCCL_SUCCESS) { |
|
|
|