@@ -392,6 +392,8 @@ REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); | |||||
REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | ||||
REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | ||||
REGISTER_OPTYPE_DEFINE(HCOMALLTOALLV, "HcomAllToAllV"); | |||||
REGISTER_OPTYPE_DEFINE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV"); | |||||
REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); | ||||
REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | ||||
@@ -22,8 +22,8 @@ | |||||
#include "graph/manager/util/hcom_util.h" | #include "graph/manager/util/hcom_util.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "hccl/hcom.h" | |||||
#include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
#include "hccl/hcom.h" | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
@@ -31,9 +31,14 @@ constexpr size_t kVarTableDims = 2; | |||||
constexpr size_t kVarTableRowCnt = 3; | constexpr size_t kVarTableRowCnt = 3; | ||||
constexpr size_t kVarTableIdxAddr = 1; | constexpr size_t kVarTableIdxAddr = 1; | ||||
constexpr size_t kVarTableIdxLen = 2; | constexpr size_t kVarTableIdxLen = 2; | ||||
// input anchor nums according to IR | |||||
constexpr size_t kAllToAllVInputNums = 5; | |||||
constexpr size_t kGatherAllToAllVInputNums = 4; | |||||
const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD }; | const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD }; | ||||
const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE }; | const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE }; | ||||
const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE }; | const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE }; | ||||
const std::set<std::string> kAllToAllTypes = {HCOMALLTOALLV, HCOMGATHERALLTOALLV}; | |||||
} // namespace | } // namespace | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -349,6 +354,121 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { | |||||
void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, | |||||
¶ms.recvcounts, ¶ms.rdispls}; | |||||
for (size_t i = 0; i < kAllToAllVInputNums; ++i) { | |||||
auto addr = context.MutableInput(i); | |||||
GE_CHECK_NOTNULL(addr); | |||||
*input_addrs[i] = addr->MutableData(); | |||||
} | |||||
auto recv_tv = context.MutableOutput(0); | |||||
GE_CHECK_NOTNULL(recv_tv); | |||||
params.recvbuf = recv_tv->MutableData(); | |||||
const NodeItem &node_item = context.GetNodeItem(); | |||||
const OpDescPtr op_desc = node_item.GetOpDesc(); | |||||
auto input_desc = node_item.MutableInputDesc(0); | |||||
GE_CHECK_NOTNULL(input_desc); | |||||
ge::DataType src_data_type = input_desc->GetDataType(); | |||||
auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type)); | |||||
if (iter == kConstOpHcclDataType.end()) { | |||||
REPORT_INNER_ERROR("E19999", "%s alltoallv datatype:%s not support.", op_desc->GetName().c_str(), | |||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
GELOGE(PARAM_INVALID, "[Find][DataType]%s alltoallv datatype:%s not support.", op_desc->GetName().c_str(), | |||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
params.sendtype = iter->second; | |||||
params.recvtype = iter->second; | |||||
return SUCCESS; | |||||
} | |||||
Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams ¶ms) { | |||||
void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, | |||||
¶ms.recvcounts, ¶ms.rdispls}; | |||||
for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) { | |||||
auto addr = context.MutableInput(i); | |||||
GE_CHECK_NOTNULL(addr); | |||||
*input_addrs[i] = addr->MutableData(); | |||||
} | |||||
auto recv_tv = context.MutableOutput(0); | |||||
GE_CHECK_NOTNULL(recv_tv); | |||||
params.recvbuf = recv_tv->MutableData(); | |||||
auto gathered_tv = context.MutableOutput(1); | |||||
GE_CHECK_NOTNULL(gathered_tv); | |||||
params.gatheredbuf = gathered_tv->MutableData(); | |||||
const NodeItem &node_item = context.GetNodeItem(); | |||||
const OpDescPtr op_desc = node_item.GetOpDesc(); | |||||
ge::DataType data_type = ge::DT_FLOAT; | |||||
(void)ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, data_type); | |||||
auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(data_type)); | |||||
if (iter == kConstOpHcclDataType.end()) { | |||||
REPORT_INNER_ERROR("E19999", "%s received datatype:%s not support.", op_desc->GetName().c_str(), | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
GELOGE(PARAM_INVALID, "[Find][DataType]%s received datatype:%s not support.", op_desc->GetName().c_str(), | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
params.recvtype = iter->second; | |||||
int64_t addr_len; | |||||
(void) ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); | |||||
params.addrLength = static_cast<int>(addr_len); | |||||
return SUCCESS; | |||||
} | |||||
Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | |||||
GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName()); | |||||
TaskContext *p_ctx = &context; | |||||
auto callback = [p_ctx, done_callback](HcclResult status){ | |||||
if (status != HCCL_SUCCESS) { | |||||
GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName()); | |||||
p_ctx->SetStatus(FAILED); | |||||
} | |||||
done_callback(); | |||||
GELOGI("[%s] AllToAllNodeTask callback successfully.", p_ctx->GetNodeName()); | |||||
}; | |||||
if (context.GetNodeItem().NodeType() == HCOMALLTOALLV) { | |||||
auto HcomExecEnqueueAllToAllV = (HcclResult(*)(HcomAllToAllVParams, std::function<void(HcclResult status)>))dlsym( | |||||
context.handle_, "HcomExecEnqueueAllToAllV"); | |||||
if (HcomExecEnqueueAllToAllV == nullptr) { | |||||
GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueAllToAllV] for node:%s.",context.GetNodeName()); | |||||
return FAILED; | |||||
} | |||||
HcomAllToAllVParams params; | |||||
GE_CHK_STATUS_RET(BuildAllToAllVparams(context, params)); | |||||
HcclResult hccl_ret = HcomExecEnqueueAllToAllV(params, callback); | |||||
if (hccl_ret != HCCL_SUCCESS) { | |||||
GELOGE(HCCL_E_INTERNAL, "AllToAllV teak enqueue failed for node [%s].", context.GetNodeName()); | |||||
return HCCL_E_INTERNAL; | |||||
} | |||||
} else { | |||||
auto HcomExecEnqueueGatherAllToAllV = | |||||
(HcclResult(*)(HcomGatherAllToAllVParams, std::function<void(HcclResult status)>))dlsym( | |||||
context.handle_, "HcomExecEnqueueGatherAllToAllV"); | |||||
if (HcomExecEnqueueGatherAllToAllV == nullptr) { | |||||
GELOGE(FAILED, "Failed to invoke function [HcomExecEnqueueGatherAllToAllV] for node:%s.", context.GetNodeName()); | |||||
return FAILED; | |||||
} | |||||
HcomGatherAllToAllVParams params; | |||||
GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params)); | |||||
HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback); | |||||
if (hccl_ret != HCCL_SUCCESS) { | |||||
GELOGE(HCCL_E_INTERNAL, "GatherAllToAllV teak enqueue failed for node [%s].", context.GetNodeName()); | |||||
return HCCL_E_INTERNAL; | |||||
} | |||||
} | |||||
GELOGI("[%s] AllToAllNodeTask::ExecuteAsync success.", context.GetNodeName()); | |||||
return SUCCESS; | |||||
} | |||||
Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } | Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } | ||||
Status HcclNodeTask::Init(TaskContext &context) { | Status HcclNodeTask::Init(TaskContext &context) { | ||||
@@ -379,6 +499,8 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) { | if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) { | ||||
task = MakeShared<RdmaNodeTask>(); | task = MakeShared<RdmaNodeTask>(); | ||||
} else if (kAllToAllTypes.count(node->GetType()) > 0) { | |||||
task = MakeShared<AllToAllNodeTask>(); | |||||
} else { | } else { | ||||
task = MakeShared<HcclNodeTask>(); | task = MakeShared<HcclNodeTask>(); | ||||
} | } | ||||
@@ -65,6 +65,22 @@ class RdmaNodeTask : public NodeTask { | |||||
bool skip_flag_; | bool skip_flag_; | ||||
}; | }; | ||||
class AllToAllNodeTask : public NodeTask { | |||||
public: | |||||
AllToAllNodeTask() = default; | |||||
~AllToAllNodeTask() = default; | |||||
Status UpdateArgs(TaskContext &context) override { return SUCCESS; } | |||||
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | |||||
Status Init(TaskContext &context) override { return SUCCESS; } | |||||
private: | |||||
std::mutex hccl_mutex_; | |||||
std::condition_variable cond_; | |||||
}; | |||||
class HcclNodeExecutor : public NodeExecutor { | class HcclNodeExecutor : public NodeExecutor { | ||||
public: | public: | ||||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | ||||
@@ -441,6 +441,8 @@ REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | |||||
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | ||||
REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); | ||||
REGISTER_OPTYPE_DECLARE(HCOMALLTOALLV, "HcomAllToAllV"); | |||||
REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV"); | |||||
REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | ||||
REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | ||||
@@ -42,3 +42,14 @@ HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_pt | |||||
HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { | HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { | ||||
return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
} | } | ||||
HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback) { | |||||
return HCCL_SUCCESS; | |||||
} | |||||
HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params, | |||||
std::function<void(HcclResult status)> callback) { | |||||
return HCCL_SUCCESS; | |||||
} | |||||
@@ -1,108 +1,240 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gmock/gmock.h> | |||||
#include <gtest/gtest.h> | |||||
#include <vector> | |||||
#define private public | |||||
#define protected public | |||||
#include "graph/runtime_inference_context.h" | |||||
#include "hybrid/executor/subgraph_context.h" | |||||
#include "hybrid/node_executor/hccl/hccl_node_executor.h" | |||||
#undef protected | |||||
#undef private | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
using namespace hybrid; | |||||
class UtestHcclNodeExecutor : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
op_desc->SetStreamId(0); | |||||
static int32_t index = 0; | |||||
op_desc->SetId(index++); | |||||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
TensorUtils::SetSize(tensor, 64); | |||||
vector<int64_t> input_offset; | |||||
for (int i = 0; i < in_num; i++) { | |||||
op_desc->AddInputDesc(tensor); | |||||
input_offset.emplace_back(i * 64); | |||||
} | |||||
op_desc->SetInputOffset(input_offset); | |||||
vector<int64_t> output_offset; | |||||
for (int i = 0; i < out_num; i++) { | |||||
op_desc->AddOutputDesc(tensor); | |||||
output_offset.emplace_back(in_num * 64 + i * 64); | |||||
} | |||||
op_desc->SetOutputOffset(output_offset); | |||||
return graph.AddNode(op_desc); | |||||
} | |||||
TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
NodeItem *node_item = new_node.get(); | |||||
node_item->input_start = 0; | |||||
node_item->output_start = 0; | |||||
GraphItem graph_item; | |||||
GraphExecutionContext graph_context; | |||||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
ASSERT_NE(node_state, nullptr); | |||||
RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id)); | |||||
RuntimeInferenceContext *ctx = nullptr; | |||||
RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx); | |||||
Shape s({1, 3}); | |||||
TensorDesc tensor_desc(s); | |||||
Tensor tensor(tensor_desc); | |||||
std::vector<uint8_t> data = {1, 2, 3, 4}; | |||||
tensor.SetData(data); | |||||
ctx->SetTensor(1, 0, tensor.Clone()); | |||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
vector<HcomRemoteAccessAddrInfo> addr_infos; | |||||
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||||
task->remote_index_ = {1, 0}; | |||||
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
Shape s2({1}); | |||||
TensorDesc tensor_desc2(s2); | |||||
Tensor tensor2(tensor_desc2); | |||||
ctx->SetTensor(1, 0, tensor2.Clone()); | |||||
task->ExtractTensor(*unique_task_context, addr_infos); | |||||
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||||
} | |||||
} // namespace ge | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gmock/gmock.h> | |||||
#include <gtest/gtest.h> | |||||
#include <vector> | |||||
#define private public | |||||
#define protected public | |||||
#include "graph/runtime_inference_context.h" | |||||
#include "hybrid/executor/subgraph_context.h" | |||||
#include "hybrid/node_executor/hccl/hccl_node_executor.h" | |||||
#undef protected | |||||
#undef private | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace { | |||||
const string kHcclSoPath = "../build/tests/depends/hccl/libhccl_stub.so"; | |||||
} | |||||
namespace ge { | |||||
using namespace hybrid; | |||||
class UtestHcclNodeExecutor : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
op_desc->SetStreamId(0); | |||||
static int32_t index = 0; | |||||
op_desc->SetId(index++); | |||||
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
TensorUtils::SetSize(tensor, 64); | |||||
vector<int64_t> input_offset; | |||||
for (int i = 0; i < in_num; i++) { | |||||
op_desc->AddInputDesc(tensor); | |||||
input_offset.emplace_back(i * 64); | |||||
} | |||||
op_desc->SetInputOffset(input_offset); | |||||
vector<int64_t> output_offset; | |||||
for (int i = 0; i < out_num; i++) { | |||||
op_desc->AddOutputDesc(tensor); | |||||
output_offset.emplace_back(in_num * 64 + i * 64); | |||||
} | |||||
op_desc->SetOutputOffset(output_offset); | |||||
return graph.AddNode(op_desc); | |||||
} | |||||
TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
NodeItem *node_item = new_node.get(); | |||||
node_item->input_start = 0; | |||||
node_item->output_start = 0; | |||||
GraphItem graph_item; | |||||
GraphExecutionContext graph_context; | |||||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
ASSERT_NE(node_state, nullptr); | |||||
RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id)); | |||||
RuntimeInferenceContext *ctx = nullptr; | |||||
RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx); | |||||
Shape s({1, 3}); | |||||
TensorDesc tensor_desc(s); | |||||
Tensor tensor(tensor_desc); | |||||
std::vector<uint8_t> data = {1, 2, 3, 4}; | |||||
tensor.SetData(data); | |||||
ctx->SetTensor(1, 0, tensor.Clone()); | |||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
vector<HcomRemoteAccessAddrInfo> addr_infos; | |||||
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||||
task->remote_index_ = {1, 0}; | |||||
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
Shape s2({1}); | |||||
TensorDesc tensor_desc2(s2); | |||||
Tensor tensor2(tensor_desc2); | |||||
ctx->SetTensor(1, 0, tensor2.Clone()); | |||||
task->ExtractTensor(*unique_task_context, addr_infos); | |||||
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||||
} | |||||
TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
ge_root_model->SetModelName("test_name"); | |||||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
NodePtr node = CreateNode(*graph, "gatheralltoallv", HCOMGATHERALLTOALLV, 4, 2); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
NodeItem *node_item = new_node.get(); | |||||
hybrid_model.node_items_[node] = std::move(new_node); | |||||
node_item->input_start = 0; | |||||
node_item->output_start = 0; | |||||
GraphItem graph_item; | |||||
graph_item.node_items_.emplace_back(node_item); | |||||
graph_item.total_inputs_ = 4; | |||||
graph_item.total_outputs_ = 2; | |||||
GraphExecutionContext graph_context; | |||||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
ASSERT_NE(node_state, nullptr); | |||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
for (int i=0; i<4; ++i) { | |||||
uint64_t value_0 = 512; | |||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||||
} | |||||
uint64_t value_0 = 512; | |||||
TensorValue out_tensor0(&value_0, sizeof(value_0)); | |||||
subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
uint64_t value_1 = 512; | |||||
TensorValue out_tensor1(&value_1, sizeof(value_1)); | |||||
subgraph_context.SetOutput(*node_item, 1, out_tensor1); | |||||
NodeTaskPtr task = nullptr; | |||||
HcclNodeExecutor node_executor; | |||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
ASSERT_NE(task, nullptr); | |||||
auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
ASSERT_NE(handle, nullptr); | |||||
node_state->GetTaskContext()->handle_ = handle; | |||||
std::function<void()> done = []() {}; | |||||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
if (handle = nullptr) { | |||||
dlclose(handle); | |||||
} | |||||
} | |||||
TEST_F(UtestHcclNodeExecutor, alltoallv_execute) { | |||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
ge_root_model->SetModelName("test_name"); | |||||
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
NodePtr node = CreateNode(*graph, "alltoallv", HCOMALLTOALLV, 5, 1); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
NodeItem *node_item = new_node.get(); | |||||
hybrid_model.node_items_[node] = std::move(new_node); | |||||
node_item->input_start = 0; | |||||
node_item->output_start = 0; | |||||
GraphItem graph_item; | |||||
graph_item.node_items_.emplace_back(node_item); | |||||
graph_item.total_inputs_ = 5; | |||||
graph_item.total_outputs_ = 1; | |||||
GraphExecutionContext graph_context; | |||||
SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
ASSERT_NE(node_state, nullptr); | |||||
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
ASSERT_NE(unique_task_context, nullptr); | |||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
for (int i=0; i<5; ++i) { | |||||
uint64_t value_0 = 512; | |||||
TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||||
} | |||||
uint64_t value_1 = 512; | |||||
TensorValue out_tensor0(&value_1, sizeof(value_1)); | |||||
subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
NodeTaskPtr task = nullptr; | |||||
HcclNodeExecutor node_executor; | |||||
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
ASSERT_NE(task, nullptr); | |||||
auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
ASSERT_NE(handle, nullptr); | |||||
node_state->GetTaskContext()->handle_ = handle; | |||||
std::function<void()> done = []() {}; | |||||
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
if (handle = nullptr) { | |||||
dlclose(handle); | |||||
} | |||||
} | |||||
} // namespace ge | |||||
@@ -123,6 +123,30 @@ struct HcomRemoteAccessAddrInfo { | |||||
u64 length; // Memory Length in Bytes | u64 length; // Memory Length in Bytes | ||||
}; | }; | ||||
struct HcomAllToAllVParams { | |||||
void *sendbuf; | |||||
void *sendcounts; | |||||
void *sdispls; | |||||
HcclDataType sendtype; | |||||
void *recvbuf; | |||||
void *recvcounts; | |||||
void *rdispls; | |||||
HcclDataType recvtype; | |||||
const char *group; | |||||
}; | |||||
struct HcomGatherAllToAllVParams { | |||||
void *addrInfo; | |||||
void *addrInfoCountPerRank; | |||||
void *recvbuf; | |||||
void *recvcounts; | |||||
void *rdispls; | |||||
void *gatheredbuf; | |||||
s32 addrLength; | |||||
HcclDataType recvtype; | |||||
const char *group; | |||||
}; | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif // __cplusplus | #endif // __cplusplus | ||||
@@ -164,6 +164,11 @@ HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType, | |||||
const std::vector<HcomRemoteAccessAddrInfo>& addrInfos, | const std::vector<HcomRemoteAccessAddrInfo>& addrInfos, | ||||
std::function<void(HcclResult status)> callback); | std::function<void(HcclResult status)> callback); | ||||
HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback); | |||||
HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params, | |||||
std::function<void(HcclResult status)> callback); | |||||
/** | /** | ||||
* @brief Register memories and init resources for remote access. | * @brief Register memories and init resources for remote access. | ||||
* | * | ||||