From: @isaacxr Reviewed-by: @wqtshg, @zhangxiaokun9 Signed-off-by: tags/v1.3.0
| @@ -392,6 +392,8 @@ REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); | |||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMALLTOALLV, "HcomAllToAllV"); | |||||
| REGISTER_OPTYPE_DEFINE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV"); | |||||
| REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); | ||||
| REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | ||||
| @@ -22,8 +22,8 @@ | |||||
| #include "graph/manager/util/hcom_util.h" | #include "graph/manager/util/hcom_util.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "hccl/hcom.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hccl/hcom.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| @@ -31,9 +31,14 @@ constexpr size_t kVarTableDims = 2; | |||||
| constexpr size_t kVarTableRowCnt = 3; | constexpr size_t kVarTableRowCnt = 3; | ||||
| constexpr size_t kVarTableIdxAddr = 1; | constexpr size_t kVarTableIdxAddr = 1; | ||||
| constexpr size_t kVarTableIdxLen = 2; | constexpr size_t kVarTableIdxLen = 2; | ||||
| // number of input anchors according to the IR definition |||||
| constexpr size_t kAllToAllVInputNums = 5; | |||||
| constexpr size_t kGatherAllToAllVInputNums = 4; | |||||
| const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD }; | const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD }; | ||||
| const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE }; | const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE }; | ||||
| const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE }; | const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE }; | ||||
| const std::set<std::string> kAllToAllTypes = {HCOMALLTOALLV, HCOMGATHERALLTOALLV}; | |||||
| } // namespace | } // namespace | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -349,6 +354,121 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams &params) { |||||
| void **input_addrs[kAllToAllVInputNums] = {&params.sendbuf, &params.sendcounts, &params.sdispls, |||||
| &params.recvcounts, &params.rdispls}; |||||
| for (size_t i = 0; i < kAllToAllVInputNums; ++i) { | |||||
| auto addr = context.MutableInput(i); | |||||
| GE_CHECK_NOTNULL(addr); | |||||
| *input_addrs[i] = addr->MutableData(); | |||||
| } | |||||
| auto recv_tv = context.MutableOutput(0); | |||||
| GE_CHECK_NOTNULL(recv_tv); | |||||
| params.recvbuf = recv_tv->MutableData(); | |||||
| const NodeItem &node_item = context.GetNodeItem(); | |||||
| const OpDescPtr op_desc = node_item.GetOpDesc(); | |||||
| auto input_desc = node_item.MutableInputDesc(0); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| ge::DataType src_data_type = input_desc->GetDataType(); | |||||
| auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type)); | |||||
| if (iter == kConstOpHcclDataType.end()) { | |||||
| REPORT_INNER_ERROR("E19999", "%s alltoallv datatype:%s not supported.", op_desc->GetName().c_str(), |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Find][DataType]%s alltoallv datatype:%s not supported.", op_desc->GetName().c_str(), |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| params.sendtype = iter->second; | |||||
| params.recvtype = iter->second; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams &params) { |||||
| void **input_addrs[kGatherAllToAllVInputNums] = {&params.addrInfo, &params.addrInfoCountPerRank, |||||
| &params.recvcounts, &params.rdispls}; |||||
| for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) { | |||||
| auto addr = context.MutableInput(i); | |||||
| GE_CHECK_NOTNULL(addr); | |||||
| *input_addrs[i] = addr->MutableData(); | |||||
| } | |||||
| auto recv_tv = context.MutableOutput(0); | |||||
| GE_CHECK_NOTNULL(recv_tv); | |||||
| params.recvbuf = recv_tv->MutableData(); | |||||
| auto gathered_tv = context.MutableOutput(1); | |||||
| GE_CHECK_NOTNULL(gathered_tv); | |||||
| params.gatheredbuf = gathered_tv->MutableData(); | |||||
| const NodeItem &node_item = context.GetNodeItem(); | |||||
| const OpDescPtr op_desc = node_item.GetOpDesc(); | |||||
| ge::DataType data_type = ge::DT_FLOAT; | |||||
| (void)ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, data_type); | |||||
| auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(data_type)); | |||||
| if (iter == kConstOpHcclDataType.end()) { | |||||
| REPORT_INNER_ERROR("E19999", "%s received datatype:%s not supported.", op_desc->GetName().c_str(), |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Find][DataType]%s received datatype:%s not supported.", op_desc->GetName().c_str(), |||||
| TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| params.recvtype = iter->second; | |||||
| int64_t addr_len = 0; |||||
| (void) ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); | |||||
| params.addrLength = static_cast<int>(addr_len); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | |||||
| GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName()); | |||||
| TaskContext *p_ctx = &context; | |||||
| auto callback = [p_ctx, done_callback](HcclResult status){ | |||||
| if (status != HCCL_SUCCESS) { | |||||
| GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName()); | |||||
| p_ctx->SetStatus(FAILED); | |||||
| } | |||||
| done_callback(); | |||||
| GELOGI("[%s] AllToAllNodeTask callback finished successfully.", p_ctx->GetNodeName()); |||||
| }; | |||||
| if (context.GetNodeItem().NodeType() == HCOMALLTOALLV) { | |||||
| auto HcomExecEnqueueAllToAllV = (HcclResult(*)(HcomAllToAllVParams, std::function<void(HcclResult status)>))dlsym( | |||||
| context.handle_, "HcomExecEnqueueAllToAllV"); | |||||
| if (HcomExecEnqueueAllToAllV == nullptr) { | |||||
| GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueAllToAllV] for node:%s.", context.GetNodeName()); |||||
| return FAILED; | |||||
| } | |||||
| HcomAllToAllVParams params; | |||||
| GE_CHK_STATUS_RET(BuildAllToAllVparams(context, params)); | |||||
| HcclResult hccl_ret = HcomExecEnqueueAllToAllV(params, callback); | |||||
| if (hccl_ret != HCCL_SUCCESS) { | |||||
| GELOGE(HCCL_E_INTERNAL, "AllToAllV task enqueue failed for node [%s].", context.GetNodeName()); |||||
| return HCCL_E_INTERNAL; | |||||
| } | |||||
| } else { | |||||
| auto HcomExecEnqueueGatherAllToAllV = | |||||
| (HcclResult(*)(HcomGatherAllToAllVParams, std::function<void(HcclResult status)>))dlsym( | |||||
| context.handle_, "HcomExecEnqueueGatherAllToAllV"); | |||||
| if (HcomExecEnqueueGatherAllToAllV == nullptr) { | |||||
| GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueGatherAllToAllV] for node:%s.", context.GetNodeName()); | |||||
| return FAILED; | |||||
| } | |||||
| HcomGatherAllToAllVParams params; | |||||
| GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params)); | |||||
| HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback); | |||||
| if (hccl_ret != HCCL_SUCCESS) { | |||||
| GELOGE(HCCL_E_INTERNAL, "GatherAllToAllV task enqueue failed for node [%s].", context.GetNodeName()); |||||
| return HCCL_E_INTERNAL; | |||||
| } | |||||
| } | |||||
| GELOGI("[%s] AllToAllNodeTask::ExecuteAsync success.", context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } | Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } | ||||
| Status HcclNodeTask::Init(TaskContext &context) { | Status HcclNodeTask::Init(TaskContext &context) { | ||||
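Review note: the inline function-pointer casts around dlsym in AllToAllNodeTask::ExecuteAsync are dense. The pattern they implement is simply "resolve the HCCL enqueue entry point from the already-opened libhccl handle, then enqueue with a completion callback". A minimal standalone sketch of that pattern, assuming only the HcomExecEnqueueAllToAllV signature declared in hcom.h by this change (the helper name and error handling are illustrative, not part of the patch):

#include <dlfcn.h>
#include <functional>
#include "hccl/hcom.h"  // HcomAllToAllVParams, HcclResult

// Matches the declaration added to hcom.h in this change.
using EnqueueAllToAllVFn = HcclResult (*)(HcomAllToAllVParams, std::function<void(HcclResult)>);

// Resolve the symbol from an already dlopen()ed HCCL library handle and enqueue the exchange.
// Returns false if the symbol is missing or the enqueue itself fails.
bool EnqueueAllToAllV(void *hccl_handle, const HcomAllToAllVParams &params,
                      std::function<void(HcclResult)> on_done) {
  auto fn = reinterpret_cast<EnqueueAllToAllVFn>(dlsym(hccl_handle, "HcomExecEnqueueAllToAllV"));
  if (fn == nullptr) {
    return false;  // symbol not exported by the loaded library
  }
  return fn(params, std::move(on_done)) == HCCL_SUCCESS;
}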
| @@ -379,6 +499,8 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) { | if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) { | ||||
| task = MakeShared<RdmaNodeTask>(); | task = MakeShared<RdmaNodeTask>(); | ||||
| } else if (kAllToAllTypes.count(node->GetType()) > 0) { | |||||
| task = MakeShared<AllToAllNodeTask>(); | |||||
| } else { | } else { | ||||
| task = MakeShared<HcclNodeTask>(); | task = MakeShared<HcclNodeTask>(); | ||||
| } | } | ||||
| @@ -65,6 +65,22 @@ class RdmaNodeTask : public NodeTask { | |||||
| bool skip_flag_; | bool skip_flag_; | ||||
| }; | }; | ||||
| class AllToAllNodeTask : public NodeTask { | |||||
| public: | |||||
| AllToAllNodeTask() = default; | |||||
| ~AllToAllNodeTask() = default; | |||||
| Status UpdateArgs(TaskContext &context) override { return SUCCESS; } | |||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | |||||
| Status Init(TaskContext &context) override { return SUCCESS; } | |||||
| private: | |||||
| std::mutex hccl_mutex_; | |||||
| std::condition_variable cond_; | |||||
| }; | |||||
| class HcclNodeExecutor : public NodeExecutor { | class HcclNodeExecutor : public NodeExecutor { | ||||
| public: | public: | ||||
| Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | ||||
| @@ -441,6 +441,8 @@ REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMALLTOALLV, "HcomAllToAllV"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV"); | |||||
| REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | ||||
| REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | ||||
| @@ -42,3 +42,14 @@ HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_pt | |||||
| HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { | HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { | ||||
| return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
| } | } | ||||
| HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback) { | |||||
| return HCCL_SUCCESS; | |||||
| } | |||||
| HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params, | |||||
| std::function<void(HcclResult status)> callback) { | |||||
| return HCCL_SUCCESS; | |||||
| } | |||||
| @@ -1,108 +1,240 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gmock/gmock.h> | |||||
| #include <gtest/gtest.h> | |||||
| #include <vector> | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "hybrid/executor/subgraph_context.h" | |||||
| #include "hybrid/node_executor/hccl/hccl_node_executor.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace { | |||||
| const string kHcclSoPath = "../build/tests/depends/hccl/libhccl_stub.so"; | |||||
| } | |||||
| namespace ge { | |||||
| using namespace hybrid; | |||||
| class UtestHcclNodeExecutor : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
| TensorUtils::SetSize(tensor, 64); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(i * 64); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(in_num * 64 + i * 64); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id)); | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx); | |||||
| Shape s({1, 3}); | |||||
| TensorDesc tensor_desc(s); | |||||
| Tensor tensor(tensor_desc); | |||||
| std::vector<uint8_t> data = {1, 2, 3, 4}; | |||||
| tensor.SetData(data); | |||||
| ctx->SetTensor(1, 0, tensor.Clone()); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| vector<HcomRemoteAccessAddrInfo> addr_infos; | |||||
| shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||||
| task->remote_index_ = {1, 0}; | |||||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
| Shape s2({1}); | |||||
| TensorDesc tensor_desc2(s2); | |||||
| Tensor tensor2(tensor_desc2); | |||||
| ctx->SetTensor(1, 0, tensor2.Clone()); | |||||
| task->ExtractTensor(*unique_task_context, addr_infos); | |||||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||||
| } | |||||
| TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "gatheralltoallv", HCOMGATHERALLTOALLV, 4, 2); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 4; | |||||
| graph_item.total_outputs_ = 2; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| for (int i = 0; i < 4; ++i) { |||||
| uint64_t value_0 = 512; |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); |||||
| subgraph_context.SetInput(*node_item, i, in_tensor0); |||||
| } |||||
| uint64_t value_0 = 512; | |||||
| TensorValue out_tensor0(&value_0, sizeof(value_0)); | |||||
| subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
| uint64_t value_1 = 512; | |||||
| TensorValue out_tensor1(&value_1, sizeof(value_1)); | |||||
| subgraph_context.SetOutput(*node_item, 1, out_tensor1); | |||||
| NodeTaskPtr task = nullptr; | |||||
| HcclNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
| ASSERT_NE(handle, nullptr); | |||||
| node_state->GetTaskContext()->handle_ = handle; | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| if (handle != nullptr) { |||||
| dlclose(handle); |||||
| } |||||
| } | |||||
| TEST_F(UtestHcclNodeExecutor, alltoallv_execute) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "alltoallv", HCOMALLTOALLV, 5, 1); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 5; | |||||
| graph_item.total_outputs_ = 1; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| for (int i = 0; i < 5; ++i) { |||||
| uint64_t value_0 = 512; |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); |||||
| subgraph_context.SetInput(*node_item, i, in_tensor0); |||||
| } |||||
| uint64_t value_1 = 512; | |||||
| TensorValue out_tensor0(&value_1, sizeof(value_1)); | |||||
| subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
| NodeTaskPtr task = nullptr; | |||||
| HcclNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
| ASSERT_NE(handle, nullptr); | |||||
| node_state->GetTaskContext()->handle_ = handle; | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| if (handle != nullptr) { |||||
| dlclose(handle); |||||
| } |||||
| } | |||||
| } // namespace ge | |||||
| @@ -123,6 +123,30 @@ struct HcomRemoteAccessAddrInfo { | |||||
| u64 length; // Memory Length in Bytes | u64 length; // Memory Length in Bytes | ||||
| }; | }; | ||||
| struct HcomAllToAllVParams { | |||||
| void *sendbuf; | |||||
| void *sendcounts; | |||||
| void *sdispls; | |||||
| HcclDataType sendtype; | |||||
| void *recvbuf; | |||||
| void *recvcounts; | |||||
| void *rdispls; | |||||
| HcclDataType recvtype; | |||||
| const char *group; | |||||
| }; | |||||
| struct HcomGatherAllToAllVParams { | |||||
| void *addrInfo; | |||||
| void *addrInfoCountPerRank; | |||||
| void *recvbuf; | |||||
| void *recvcounts; | |||||
| void *rdispls; | |||||
| void *gatheredbuf; | |||||
| s32 addrLength; | |||||
| HcclDataType recvtype; | |||||
| const char *group; | |||||
| }; | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif // __cplusplus | #endif // __cplusplus | ||||
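The two new structs follow MPI-style AllToAllV semantics: flat send/receive buffers plus per-rank count and displacement arrays, with the element type carried separately as an HcclDataType. A small sketch of how a caller might populate HcomAllToAllVParams outside the executor; the 64-bit width of the count/displacement arrays, the FP32 data type, and the group name are assumptions for illustration, while in this patch the addresses come from the node's input/output tensors in BuildAllToAllVparams:

#include <cstdint>
#include <vector>
#include "hccl/hcom.h"  // HcomAllToAllVParams, HcclDataType

// Illustrative helper: wire user-owned buffers into the params struct.
HcomAllToAllVParams MakeAllToAllVParams(std::vector<float> &send_buf, std::vector<int64_t> &send_counts,
                                        std::vector<int64_t> &send_displs, std::vector<float> &recv_buf,
                                        std::vector<int64_t> &recv_counts, std::vector<int64_t> &recv_displs) {
  HcomAllToAllVParams params{};
  params.sendbuf = send_buf.data();
  params.sendcounts = send_counts.data();
  params.sdispls = send_displs.data();
  params.sendtype = HCCL_DATA_TYPE_FP32;  // element type of sendbuf
  params.recvbuf = recv_buf.data();
  params.recvcounts = recv_counts.data();
  params.rdispls = recv_displs.data();
  params.recvtype = HCCL_DATA_TYPE_FP32;  // element type of recvbuf
  params.group = "hccl_world_group";      // assumed default group name, illustration only
  return params;
}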
| @@ -164,6 +164,11 @@ HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType, | |||||
| const std::vector<HcomRemoteAccessAddrInfo>& addrInfos, | const std::vector<HcomRemoteAccessAddrInfo>& addrInfos, | ||||
| std::function<void(HcclResult status)> callback); | std::function<void(HcclResult status)> callback); | ||||
| HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback); | |||||
| HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params, | |||||
| std::function<void(HcclResult status)> callback); | |||||
| /** | /** | ||||
| * @brief Register memories and init resources for remote access. | * @brief Register memories and init resources for remote access. | ||||
| * | * | ||||
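For completeness, a hedged sketch of how the new enqueue entry points are driven with a completion callback. It mirrors the callback that AllToAllNodeTask::ExecuteAsync wraps around done_callback, but calls the function directly (as the unit-test stub does) rather than through dlsym; the helper name and the placeholder error comment are illustrative:

#include <functional>
#include "hccl/hcom.h"

// Enqueue an AllToAllV exchange; 'done' is invoked once HCCL reports completion,
// whether the collective succeeded or not, matching the executor's callback shape.
HcclResult EnqueueAllToAllVWithCallback(HcomAllToAllVParams params, std::function<void()> done) {
  auto callback = [done](HcclResult status) {
    if (status != HCCL_SUCCESS) {
      // real code would log the failure and mark the task context as FAILED here
    }
    done();
  };
  return HcomExecEnqueueAllToAllV(params, callback);
}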