diff --git a/ge/common/types.cc b/ge/common/types.cc
index 4aa7ce01..98ae7737 100644
--- a/ge/common/types.cc
+++ b/ge/common/types.cc
@@ -392,6 +392,8 @@ REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
+REGISTER_OPTYPE_DEFINE(HCOMALLTOALLV, "HcomAllToAllV");
+REGISTER_OPTYPE_DEFINE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");
 REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp");
diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
index 150e8ed2..72092cd8 100644
--- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
+++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
@@ -22,8 +22,8 @@
 #include "graph/manager/util/hcom_util.h"
 #include "graph/utils/type_utils.h"
 #include "graph/types.h"
-#include "hccl/hcom.h"
 #include "hybrid/executor/hybrid_execution_context.h"
+#include "hccl/hcom.h"
 
 namespace ge {
 namespace {
@@ -31,9 +31,14 @@ constexpr size_t kVarTableDims = 2;
 constexpr size_t kVarTableRowCnt = 3;
 constexpr size_t kVarTableIdxAddr = 1;
 constexpr size_t kVarTableIdxLen = 2;
+// input anchor counts according to the IR
+constexpr size_t kAllToAllVInputNums = 5;
+constexpr size_t kGatherAllToAllVInputNums = 4;
+
 const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
 const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
 const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
+const std::set<std::string> kAllToAllTypes = {HCOMALLTOALLV, HCOMGATHERALLTOALLV};
 } // namespace
 namespace hybrid {
@@ -349,6 +354,121 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
   return SUCCESS;
 }
 
+Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams &params) {
+  // Each input anchor maps, in order, onto one pointer field of the params struct.
+  void **input_addrs[kAllToAllVInputNums] = {&params.sendbuf, &params.sendcounts, &params.sdispls,
+                                             &params.recvcounts, &params.rdispls};
+  for (size_t i = 0; i < kAllToAllVInputNums; ++i) {
+    auto addr = context.MutableInput(i);
+    GE_CHECK_NOTNULL(addr);
+    *input_addrs[i] = addr->MutableData();
+  }
+  auto recv_tv = context.MutableOutput(0);
+  GE_CHECK_NOTNULL(recv_tv);
+  params.recvbuf = recv_tv->MutableData();
+
+  const NodeItem &node_item = context.GetNodeItem();
+  const OpDescPtr op_desc = node_item.GetOpDesc();
+  auto input_desc = node_item.MutableInputDesc(0);
+  GE_CHECK_NOTNULL(input_desc);
+  ge::DataType src_data_type = input_desc->GetDataType();
+  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type));
+  if (iter == kConstOpHcclDataType.end()) {
+    REPORT_INNER_ERROR("E19999", "%s alltoallv datatype:%s not support.", op_desc->GetName().c_str(),
+                       TypeUtils::DataTypeToSerialString(src_data_type).c_str());
+    GELOGE(PARAM_INVALID, "[Find][DataType]%s alltoallv datatype:%s not support.", op_desc->GetName().c_str(),
+           TypeUtils::DataTypeToSerialString(src_data_type).c_str());
+    return PARAM_INVALID;
+  }
+  params.sendtype = iter->second;
+  params.recvtype = iter->second;
+
+  return SUCCESS;
+}
+
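+// GatherAllToAllV differs from AllToAllV above: it has four inputs (addr info,
+// per-rank addr info counts, recv counts, recv displacements), two outputs
+// (recvbuf, gatheredbuf), and takes its recv data type from the HCOM_ATTR_DATA_TYPE
+// attribute rather than from an input tensor desc.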
+Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams &params) {
+  void **input_addrs[kGatherAllToAllVInputNums] = {&params.addrInfo, &params.addrInfoCountPerRank,
+                                                   &params.recvcounts, &params.rdispls};
+  for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) {
+    auto addr = context.MutableInput(i);
+    GE_CHECK_NOTNULL(addr);
+    *input_addrs[i] = addr->MutableData();
+  }
+  auto recv_tv = context.MutableOutput(0);
+  GE_CHECK_NOTNULL(recv_tv);
+  params.recvbuf = recv_tv->MutableData();
+  auto gathered_tv = context.MutableOutput(1);
+  GE_CHECK_NOTNULL(gathered_tv);
+  params.gatheredbuf = gathered_tv->MutableData();
+
+  const NodeItem &node_item = context.GetNodeItem();
+  const OpDescPtr op_desc = node_item.GetOpDesc();
+
+  ge::DataType data_type = ge::DT_FLOAT;
+  (void)ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, data_type);
+  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(data_type));
+  if (iter == kConstOpHcclDataType.end()) {
+    REPORT_INNER_ERROR("E19999", "%s received datatype:%s not support.", op_desc->GetName().c_str(),
+                       TypeUtils::DataTypeToSerialString(data_type).c_str());
+    GELOGE(PARAM_INVALID, "[Find][DataType]%s received datatype:%s not support.", op_desc->GetName().c_str(),
+           TypeUtils::DataTypeToSerialString(data_type).c_str());
+    return PARAM_INVALID;
+  }
+  params.recvtype = iter->second;
+
+  int64_t addr_len = 0;
+  (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len);
+  params.addrLength = static_cast<s32>(addr_len);
+
+  return SUCCESS;
+}
+
+Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName());
+
+  TaskContext *p_ctx = &context;
+  auto callback = [p_ctx, done_callback](HcclResult status) {
+    if (status != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName());
+      p_ctx->SetStatus(FAILED);
+    }
+    done_callback();
+    GELOGI("[%s] AllToAllNodeTask callback successfully.", p_ctx->GetNodeName());
+  };
+
+  if (context.GetNodeItem().NodeType() == HCOMALLTOALLV) {
+    // Resolve the enqueue entry point from the HCCL library held on the context handle.
+    auto HcomExecEnqueueAllToAllV = (HcclResult(*)(HcomAllToAllVParams, std::function<void(HcclResult)>))dlsym(
+        context.handle_, "HcomExecEnqueueAllToAllV");
+    if (HcomExecEnqueueAllToAllV == nullptr) {
+      GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueAllToAllV] for node:%s.", context.GetNodeName());
+      return FAILED;
+    }
+    HcomAllToAllVParams params;
+    GE_CHK_STATUS_RET(BuildAllToAllVparams(context, params));
+    HcclResult hccl_ret = HcomExecEnqueueAllToAllV(params, callback);
+    if (hccl_ret != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "AllToAllV task enqueue failed for node [%s].", context.GetNodeName());
+      return HCCL_E_INTERNAL;
+    }
+  } else {
+    auto HcomExecEnqueueGatherAllToAllV =
+        (HcclResult(*)(HcomGatherAllToAllVParams, std::function<void(HcclResult)>))dlsym(
+            context.handle_, "HcomExecEnqueueGatherAllToAllV");
+    if (HcomExecEnqueueGatherAllToAllV == nullptr) {
+      GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueGatherAllToAllV] for node:%s.", context.GetNodeName());
+      return FAILED;
+    }
+    HcomGatherAllToAllVParams params;
+    GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params));
+    HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback);
+    if (hccl_ret != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "GatherAllToAllV task enqueue failed for node [%s].", context.GetNodeName());
+      return HCCL_E_INTERNAL;
+    }
+  }
+  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync success.", context.GetNodeName());
+  return SUCCESS;
+}
+
 Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; }
 
 Status HcclNodeTask::Init(TaskContext &context) {
@@ -379,6 +499,8 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
   GE_CHECK_NOTNULL(node);
   if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) {
     task = MakeShared<RdmaNodeTask>();
+  } else if (kAllToAllTypes.count(node->GetType()) > 0) {
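+    // Both op types are served by AllToAllNodeTask; ExecuteAsync dispatches on NodeType().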
+    task = MakeShared<AllToAllNodeTask>();
   } else {
     task = MakeShared<HcclNodeTask>();
   }
diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/ge/hybrid/node_executor/hccl/hccl_node_executor.h
index 9e6d41a4..b020208d 100644
--- a/ge/hybrid/node_executor/hccl/hccl_node_executor.h
+++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.h
@@ -65,6 +65,22 @@ class RdmaNodeTask : public NodeTask {
   bool skip_flag_;
 };
 
+
+class AllToAllNodeTask : public NodeTask {
+ public:
+  AllToAllNodeTask() = default;
+
+  ~AllToAllNodeTask() = default;
+
+  Status UpdateArgs(TaskContext &context) override { return SUCCESS; }
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  Status Init(TaskContext &context) override { return SUCCESS; }
+
+ private:
+  std::mutex hccl_mutex_;
+  std::condition_variable cond_;
+};
+
 class HcclNodeExecutor : public NodeExecutor {
  public:
   Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const;
diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h
index 4242118d..811d5eed 100644
--- a/inc/framework/common/types.h
+++ b/inc/framework/common/types.h
@@ -441,6 +441,8 @@ REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
+REGISTER_OPTYPE_DECLARE(HCOMALLTOALLV, "HcomAllToAllV");
+REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");
 REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp");
diff --git a/tests/depends/hccl/src/hccl_stub.cc b/tests/depends/hccl/src/hccl_stub.cc
index b9b9d4f6..5f5e513c 100644
--- a/tests/depends/hccl/src/hccl_stub.cc
+++ b/tests/depends/hccl/src/hccl_stub.cc
@@ -42,3 +42,14 @@ HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_pt
                                HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) {
   return HCCL_SUCCESS;
 }
+
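+// UT stubs: report a successful enqueue without running the collective; the real
+// HCCL library would invoke the callback once the operation completes.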
+HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult)> callback) {
+  return HCCL_SUCCESS;
+}
+
+HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
+                                          std::function<void(HcclResult)> callback) {
+  return HCCL_SUCCESS;
+}
+
diff --git a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
index c36d6ea5..afaf067e 100644
--- a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
+++ b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
@@ -1,108 +1,240 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <gtest/gtest.h>
-#include <gmock/gmock.h>
-
-#include <vector>
-
-#define private public
-#define protected public
-#include "graph/runtime_inference_context.h"
-#include "hybrid/executor/subgraph_context.h"
-#include "hybrid/node_executor/hccl/hccl_node_executor.h"
-#undef protected
-#undef private
-
-using namespace std;
-using namespace testing;
-namespace ge {
-using namespace hybrid;
-
-class UtestHcclNodeExecutor : public testing::Test {
- protected:
-  void SetUp() {}
-  void TearDown() {}
-};
-
-static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
-  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
-  op_desc->SetStreamId(0);
-  static int32_t index = 0;
-  op_desc->SetId(index++);
-
-  GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
-  TensorUtils::SetSize(tensor, 64);
-  vector<int64_t> input_offset;
-  for (int i = 0; i < in_num; i++) {
-    op_desc->AddInputDesc(tensor);
-    input_offset.emplace_back(i * 64);
-  }
-  op_desc->SetInputOffset(input_offset);
-
-  vector<int64_t> output_offset;
-  for (int i = 0; i < out_num; i++) {
-    op_desc->AddOutputDesc(tensor);
-    output_offset.emplace_back(in_num * 64 + i * 64);
-  }
-  op_desc->SetOutputOffset(output_offset);
-
-  return graph.AddNode(op_desc);
-}
-
-TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
-  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
-  NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0);
-  std::unique_ptr<NodeItem> new_node;
-  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
-  NodeItem *node_item = new_node.get();
-  node_item->input_start = 0;
-  node_item->output_start = 0;
-
-  GraphItem graph_item;
-  GraphExecutionContext graph_context;
-  SubgraphContext subgraph_context(&graph_item, &graph_context);
-  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
-
-  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
-  ASSERT_NE(node_state, nullptr);
-
-  RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id));
-  RuntimeInferenceContext *ctx = nullptr;
-  RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx);
-
-  Shape s({1, 3});
-  TensorDesc tensor_desc(s);
-  Tensor tensor(tensor_desc);
-  std::vector<uint8_t> data = {1, 2, 3, 4};
-  tensor.SetData(data);
-  ctx->SetTensor(1, 0, tensor.Clone());
-
-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  vector<HcomRemoteAccessAddrInfo> addr_infos;
-  shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>();
-  task->remote_index_ = {1, 0};
-  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
-
-  Shape s2({1});
-  TensorDesc tensor_desc2(s2);
-  Tensor tensor2(tensor_desc2);
-  ctx->SetTensor(1, 0, tensor2.Clone());
-  task->ExtractTensor(*unique_task_context, addr_infos);
-  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
-  RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
-}
-} // namespace ge
\ No newline at end of file
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
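+
+// These tests dlopen the HCCL stub library and hand its handle to the task context,
+// so the dlsym lookups in AllToAllNodeTask resolve against the stub implementations.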
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+
+#include <dlfcn.h>
+
+#define private public
+#define protected public
+#include "graph/runtime_inference_context.h"
+#include "hybrid/executor/subgraph_context.h"
+#include "hybrid/node_executor/hccl/hccl_node_executor.h"
+#undef protected
+#undef private
+
+using namespace std;
+using namespace testing;
+namespace {
+const string kHcclSoPath = "../build/tests/depends/hccl/libhccl_stub.so";
+}
+namespace ge {
+using namespace hybrid;
+
+class UtestHcclNodeExecutor : public testing::Test {
+ protected:
+  void SetUp() {}
+  void TearDown() {}
+};
+
+static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
+  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
+  op_desc->SetStreamId(0);
+  static int32_t index = 0;
+  op_desc->SetId(index++);
+
+  GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
+  TensorUtils::SetSize(tensor, 64);
+  vector<int64_t> input_offset;
+  for (int i = 0; i < in_num; i++) {
+    op_desc->AddInputDesc(tensor);
+    input_offset.emplace_back(i * 64);
+  }
+  op_desc->SetInputOffset(input_offset);
+
+  vector<int64_t> output_offset;
+  for (int i = 0; i < out_num; i++) {
+    op_desc->AddOutputDesc(tensor);
+    output_offset.emplace_back(in_num * 64 + i * 64);
+  }
+  op_desc->SetOutputOffset(output_offset);
+
+  return graph.AddNode(op_desc);
+}
+
+TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0);
+  std::unique_ptr<NodeItem> new_node;
+  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
+  NodeItem *node_item = new_node.get();
+  node_item->input_start = 0;
+  node_item->output_start = 0;
+
+  GraphItem graph_item;
+  GraphExecutionContext graph_context;
+  SubgraphContext subgraph_context(&graph_item, &graph_context);
+  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
+
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
+  ASSERT_NE(node_state, nullptr);
+
+  RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id));
+  RuntimeInferenceContext *ctx = nullptr;
+  RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx);
+
+  Shape s({1, 3});
+  TensorDesc tensor_desc(s);
+  Tensor tensor(tensor_desc);
+  std::vector<uint8_t> data = {1, 2, 3, 4};
+  tensor.SetData(data);
+  ctx->SetTensor(1, 0, tensor.Clone());
+
+  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
+  vector<HcomRemoteAccessAddrInfo> addr_infos;
+  shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>();
+  task->remote_index_ = {1, 0};
+  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
+
+  Shape s2({1});
+  TensorDesc tensor_desc2(s2);
+  Tensor tensor2(tensor_desc2);
+  ctx->SetTensor(1, 0, tensor2.Clone());
+  task->ExtractTensor(*unique_task_context, addr_infos);
+  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
+  RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
+}
+
+TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  GeModelPtr ge_sub_model = std::make_shared<GeModel>();
+  GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
+  ge_root_model->SetModelName("test_name");
+  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
+  HybridModel hybrid_model(ge_root_model);
+
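+  // GatherAllToAllV is modeled with 4 inputs (kGatherAllToAllVInputNums) and
+  // 2 outputs: the received data and the gathered buffer.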
+  NodePtr node = CreateNode(*graph, "gatheralltoallv", HCOMGATHERALLTOALLV, 4, 2);
+
+  std::unique_ptr<NodeItem> new_node;
+  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
+  NodeItem *node_item = new_node.get();
+  hybrid_model.node_items_[node] = std::move(new_node);
+  node_item->input_start = 0;
+  node_item->output_start = 0;
+
+  GraphItem graph_item;
+  graph_item.node_items_.emplace_back(node_item);
+  graph_item.total_inputs_ = 4;
+  graph_item.total_outputs_ = 2;
+
+  GraphExecutionContext graph_context;
+  SubgraphContext subgraph_context(&graph_item, &graph_context);
+  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
+  graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
+
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
+  ASSERT_NE(node_state, nullptr);
+
+  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
+  ASSERT_NE(unique_task_context, nullptr);
+  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
+  node_state->SetTaskContext(shared_task_context);
+
+  for (int i = 0; i < 4; ++i) {
+    uint64_t value_0 = 512;
+    TensorValue in_tensor0(&value_0, sizeof(value_0));
+    subgraph_context.SetInput(*node_item, i, in_tensor0);
+  }
+
+  uint64_t value_0 = 512;
+  TensorValue out_tensor0(&value_0, sizeof(value_0));
+  subgraph_context.SetOutput(*node_item, 0, out_tensor0);
+
+  uint64_t value_1 = 512;
+  TensorValue out_tensor1(&value_1, sizeof(value_1));
+  subgraph_context.SetOutput(*node_item, 1, out_tensor1);
+
+  NodeTaskPtr task = nullptr;
+  HcclNodeExecutor node_executor;
+  ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
+  ASSERT_NE(task, nullptr);
+
+  auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
+  ASSERT_NE(handle, nullptr);
+  node_state->GetTaskContext()->handle_ = handle;
+  std::function<void()> done = []() {};
+  ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
+
+  if (handle != nullptr) {
+    dlclose(handle);
+  }
+}
+
+TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  GeModelPtr ge_sub_model = std::make_shared<GeModel>();
+  GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
+  ge_root_model->SetModelName("test_name");
+  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
+  HybridModel hybrid_model(ge_root_model);
+
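+  // AllToAllV is modeled with 5 inputs (kAllToAllVInputNums: send data, send counts,
+  // send displacements, recv counts, recv displacements) and 1 output.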
+  NodePtr node = CreateNode(*graph, "alltoallv", HCOMALLTOALLV, 5, 1);
+
+  std::unique_ptr<NodeItem> new_node;
+  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
+  NodeItem *node_item = new_node.get();
+  hybrid_model.node_items_[node] = std::move(new_node);
+  node_item->input_start = 0;
+  node_item->output_start = 0;
+
+  GraphItem graph_item;
+  graph_item.node_items_.emplace_back(node_item);
+  graph_item.total_inputs_ = 5;
+  graph_item.total_outputs_ = 1;
+
+  GraphExecutionContext graph_context;
+  SubgraphContext subgraph_context(&graph_item, &graph_context);
+  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
+  graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
+
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
+  ASSERT_NE(node_state, nullptr);
+
+  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
+  ASSERT_NE(unique_task_context, nullptr);
+  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
+  node_state->SetTaskContext(shared_task_context);
+
+  for (int i = 0; i < 5; ++i) {
+    uint64_t value_0 = 512;
+    TensorValue in_tensor0(&value_0, sizeof(value_0));
+    subgraph_context.SetInput(*node_item, i, in_tensor0);
+  }
+
+  uint64_t value_1 = 512;
+  TensorValue out_tensor0(&value_1, sizeof(value_1));
+  subgraph_context.SetOutput(*node_item, 0, out_tensor0);
+  NodeTaskPtr task = nullptr;
+  HcclNodeExecutor node_executor;
+  ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
+  ASSERT_NE(task, nullptr);
+
+  auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
+  ASSERT_NE(handle, nullptr);
+  node_state->GetTaskContext()->handle_ = handle;
+
+  std::function<void()> done = []() {};
+  ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
+
+  if (handle != nullptr) {
+    dlclose(handle);
+  }
+}
+} // namespace ge
+
diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h
index 9facd20c..e57563b3 100644
--- a/third_party/fwkacllib/inc/hccl/base.h
+++ b/third_party/fwkacllib/inc/hccl/base.h
@@ -123,6 +123,30 @@ struct HcomRemoteAccessAddrInfo {
   u64 length;  // Memory Length in Bytes
 };
 
+struct HcomAllToAllVParams {
+  void *sendbuf;      // input 0 of HcomAllToAllV
+  void *sendcounts;   // input 1
+  void *sdispls;      // input 2
+  HcclDataType sendtype;
+  void *recvbuf;      // output 0
+  void *recvcounts;   // input 3
+  void *rdispls;      // input 4
+  HcclDataType recvtype;
+  const char *group;
+};
+
+struct HcomGatherAllToAllVParams {
+  void *addrInfo;              // input 0 of HcomGatherAllToAllV
+  void *addrInfoCountPerRank;  // input 1
+  void *recvbuf;               // output 0
+  void *recvcounts;            // input 2
+  void *rdispls;               // input 3
+  void *gatheredbuf;           // output 1
+  s32 addrLength;              // "addr_length" attribute
+  HcclDataType recvtype;       // HCOM_ATTR_DATA_TYPE attribute
+  const char *group;
+};
+
 #ifdef __cplusplus
 }
 #endif // __cplusplus
diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h
index 972f470c..955764d6 100644
--- a/third_party/fwkacllib/inc/hccl/hcom.h
+++ b/third_party/fwkacllib/inc/hccl/hcom.h
@@ -164,6 +164,11 @@ HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType,
                                        const std::vector<HcomRemoteAccessAddrInfo>& addrInfos,
                                        std::function<void(HcclResult status)> callback);
 
+HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult)> callback);
+
+HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
+                                          std::function<void(HcclResult)> callback);
+
 /**
  * @brief Register memories and init resources for remote access.
  *