alltoall node executor

tags/v1.3.0
isaacxr, 3 years ago
parent commit e09bc32926
8 changed files with 423 additions and 109 deletions
  1. ge/common/types.cc (+2, -0)
  2. ge/hybrid/node_executor/hccl/hccl_node_executor.cc (+123, -1)
  3. ge/hybrid/node_executor/hccl/hccl_node_executor.h (+16, -0)
  4. inc/framework/common/types.h (+2, -0)
  5. tests/depends/hccl/src/hccl_stub.cc (+11, -0)
  6. tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc (+240, -108)
  7. third_party/fwkacllib/inc/hccl/base.h (+24, -0)
  8. third_party/fwkacllib/inc/hccl/hcom.h (+5, -0)

ge/common/types.cc (+2, -0)

@@ -392,6 +392,8 @@ REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");
REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
REGISTER_OPTYPE_DEFINE(HCOMALLTOALLV, "HcomAllToAllV");
REGISTER_OPTYPE_DEFINE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");

REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");
REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp");


ge/hybrid/node_executor/hccl/hccl_node_executor.cc (+123, -1)

@@ -22,8 +22,8 @@
#include "graph/manager/util/hcom_util.h"
#include "graph/utils/type_utils.h"
#include "graph/types.h"
#include "hccl/hcom.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hccl/hcom.h"

namespace ge {
namespace {
@@ -31,9 +31,14 @@ constexpr size_t kVarTableDims = 2;
constexpr size_t kVarTableRowCnt = 3;
constexpr size_t kVarTableIdxAddr = 1;
constexpr size_t kVarTableIdxLen = 2;
// input anchor counts according to the IR definition
constexpr size_t kAllToAllVInputNums = 5;
constexpr size_t kGatherAllToAllVInputNums = 4;

const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
const std::set<std::string> kAllToAllTypes = {HCOMALLTOALLV, HCOMGATHERALLTOALLV};
} // namespace
namespace hybrid {

@@ -349,6 +354,121 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
  return SUCCESS;
}

Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams &params) {
  void **input_addrs[kAllToAllVInputNums] = {&params.sendbuf, &params.sendcounts, &params.sdispls,
                                             &params.recvcounts, &params.rdispls};
  for (size_t i = 0; i < kAllToAllVInputNums; ++i) {
    auto addr = context.MutableInput(i);
    GE_CHECK_NOTNULL(addr);
    *input_addrs[i] = addr->MutableData();
  }
  auto recv_tv = context.MutableOutput(0);
  GE_CHECK_NOTNULL(recv_tv);
  params.recvbuf = recv_tv->MutableData();

  const NodeItem &node_item = context.GetNodeItem();
  const OpDescPtr op_desc = node_item.GetOpDesc();
  auto input_desc = node_item.MutableInputDesc(0);
  GE_CHECK_NOTNULL(input_desc);
  ge::DataType src_data_type = input_desc->GetDataType();
  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type));
  if (iter == kConstOpHcclDataType.end()) {
    REPORT_INNER_ERROR("E19999", "%s alltoallv datatype:%s not supported.", op_desc->GetName().c_str(),
                       TypeUtils::DataTypeToSerialString(src_data_type).c_str());
    GELOGE(PARAM_INVALID, "[Find][DataType]%s alltoallv datatype:%s not supported.", op_desc->GetName().c_str(),
           TypeUtils::DataTypeToSerialString(src_data_type).c_str());
    return PARAM_INVALID;
  }
  params.sendtype = iter->second;
  params.recvtype = iter->second;

  return SUCCESS;
}
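
The builder above binds each TaskContext input to the matching params field through an array of pointers to those fields, so one loop covers all five inputs. A minimal standalone sketch of the same pattern, with hypothetical types standing in for the GE ones:

#include <cstddef>

// Hypothetical stand-ins for HcomAllToAllVParams and the tensor data pointers.
struct DemoParams {
  void *sendbuf;
  void *sendcounts;
  void *sdispls;
};

// One loop assigns each source buffer into the corresponding struct field:
// fields[i] holds the address of a pointer member, so *fields[i] writes it.
void BindInputs(DemoParams &params, void *bufs[3]) {
  void **fields[3] = {&params.sendbuf, &params.sendcounts, &params.sdispls};
  for (size_t i = 0; i < 3; ++i) {
    *fields[i] = bufs[i];
  }
}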

Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams &params) {
  void **input_addrs[kGatherAllToAllVInputNums] = {&params.addrInfo, &params.addrInfoCountPerRank,
                                                   &params.recvcounts, &params.rdispls};
  for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) {
    auto addr = context.MutableInput(i);
    GE_CHECK_NOTNULL(addr);
    *input_addrs[i] = addr->MutableData();
  }
  auto recv_tv = context.MutableOutput(0);
  GE_CHECK_NOTNULL(recv_tv);
  params.recvbuf = recv_tv->MutableData();
  auto gathered_tv = context.MutableOutput(1);
  GE_CHECK_NOTNULL(gathered_tv);
  params.gatheredbuf = gathered_tv->MutableData();

  const NodeItem &node_item = context.GetNodeItem();
  const OpDescPtr op_desc = node_item.GetOpDesc();

  ge::DataType data_type = ge::DT_FLOAT;
  (void)ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, data_type);
  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(data_type));
  if (iter == kConstOpHcclDataType.end()) {
    REPORT_INNER_ERROR("E19999", "%s received datatype:%s not supported.", op_desc->GetName().c_str(),
                       TypeUtils::DataTypeToSerialString(data_type).c_str());
    GELOGE(PARAM_INVALID, "[Find][DataType]%s received datatype:%s not supported.", op_desc->GetName().c_str(),
           TypeUtils::DataTypeToSerialString(data_type).c_str());
    return PARAM_INVALID;
  }
  params.recvtype = iter->second;

  int64_t addr_len = 0;  // stays 0 if the "addr_length" attribute is absent
  (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len);
  params.addrLength = static_cast<int>(addr_len);

  return SUCCESS;
}
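
Note the defaulting idiom above: data_type is pre-set to DT_FLOAT and the GetDataType return value is discarded, so a missing HCOM_ATTR_DATA_TYPE silently keeps the default; addr_len gets the same treatment. A hedged sketch of that idiom with a hypothetical attribute map, not the GE AttrUtils API:

#include <cstdint>
#include <map>
#include <string>

// Hypothetical lookup: on a missing key, return false and leave `out`
// untouched so the caller's pre-initialized default survives.
bool GetIntAttr(const std::map<std::string, int64_t> &attrs,
                const std::string &key, int64_t &out) {
  auto it = attrs.find(key);
  if (it == attrs.end()) {
    return false;
  }
  out = it->second;
  return true;
}

// Usage mirroring BuildGatherAllToAllParams: addr_len keeps its default
// when "addr_length" was never set on the op.
// int64_t addr_len = 0;
// (void)GetIntAttr(attrs, "addr_length", addr_len);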

Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync in.", context.GetNodeName());

  TaskContext *p_ctx = &context;
  auto callback = [p_ctx, done_callback](HcclResult status) {
    if (status != HCCL_SUCCESS) {
      GELOGE(HCCL_E_INTERNAL, "[%s] AllToAllNodeTask execute failed.", p_ctx->GetNodeName());
      p_ctx->SetStatus(FAILED);
    }
    done_callback();
    GELOGI("[%s] AllToAllNodeTask callback finished successfully.", p_ctx->GetNodeName());
  };

  if (context.GetNodeItem().NodeType() == HCOMALLTOALLV) {
    auto HcomExecEnqueueAllToAllV = (HcclResult(*)(HcomAllToAllVParams, std::function<void(HcclResult status)>))dlsym(
        context.handle_, "HcomExecEnqueueAllToAllV");
    if (HcomExecEnqueueAllToAllV == nullptr) {
      GELOGE(FAILED, "Failed to resolve function [HcomExecEnqueueAllToAllV] for node:%s.", context.GetNodeName());
      return FAILED;
    }
    HcomAllToAllVParams params;
    GE_CHK_STATUS_RET(BuildAllToAllVparams(context, params));
    HcclResult hccl_ret = HcomExecEnqueueAllToAllV(params, callback);
    if (hccl_ret != HCCL_SUCCESS) {
      GELOGE(HCCL_E_INTERNAL, "AllToAllV task enqueue failed for node [%s].", context.GetNodeName());
      return HCCL_E_INTERNAL;
    }
  } else {
    auto HcomExecEnqueueGatherAllToAllV =
        (HcclResult(*)(HcomGatherAllToAllVParams, std::function<void(HcclResult status)>))dlsym(
            context.handle_, "HcomExecEnqueueGatherAllToAllV");
    if (HcomExecEnqueueGatherAllToAllV == nullptr) {
      GELOGE(FAILED, "Failed to resolve function [HcomExecEnqueueGatherAllToAllV] for node:%s.", context.GetNodeName());
      return FAILED;
    }
    HcomGatherAllToAllVParams params;
    GE_CHK_STATUS_RET(BuildGatherAllToAllParams(context, params));
    HcclResult hccl_ret = HcomExecEnqueueGatherAllToAllV(params, callback);
    if (hccl_ret != HCCL_SUCCESS) {
      GELOGE(HCCL_E_INTERNAL, "GatherAllToAllV task enqueue failed for node [%s].", context.GetNodeName());
      return HCCL_E_INTERNAL;
    }
  }
  GELOGI("[%s] AllToAllNodeTask::ExecuteAsync success.", context.GetNodeName());
  return SUCCESS;
}
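
ExecuteAsync resolves the enqueue entry points at run time with dlsym on the handle the executor holds for the HCCL library, rather than linking them in. A minimal sketch of that resolution step, assuming a POSIX dlsym and an already-opened handle; the generic EnqueueFn signature is illustrative, not the HCCL one:

#include <dlfcn.h>

#include <cstdio>

// Illustrative signature; the real entry points take the Hcom*Params structs.
using EnqueueFn = int (*)(void *params);

// Look up a symbol from an already-dlopen()ed library handle; a nullptr
// result means the library does not export it, matching the checks above.
EnqueueFn ResolveEnqueue(void *handle, const char *symbol) {
  void *sym = dlsym(handle, symbol);
  if (sym == nullptr) {
    std::fprintf(stderr, "dlsym(%s) failed: %s\n", symbol, dlerror());
    return nullptr;
  }
  return reinterpret_cast<EnqueueFn>(sym);
}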

Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; }

Status HcclNodeTask::Init(TaskContext &context) {
@@ -379,6 +499,8 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
  GE_CHECK_NOTNULL(node);
  if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) {
    task = MakeShared<RdmaNodeTask>();
  } else if (kAllToAllTypes.count(node->GetType()) > 0) {
    task = MakeShared<AllToAllNodeTask>();
  } else {
    task = MakeShared<HcclNodeTask>();
  }


ge/hybrid/node_executor/hccl/hccl_node_executor.h (+16, -0)

@@ -65,6 +65,22 @@ class RdmaNodeTask : public NodeTask {
  bool skip_flag_;
};


class AllToAllNodeTask : public NodeTask {
 public:
  AllToAllNodeTask() = default;

  ~AllToAllNodeTask() = default;

  Status UpdateArgs(TaskContext &context) override { return SUCCESS; }
  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
  Status Init(TaskContext &context) override { return SUCCESS; }

 private:
  std::mutex hccl_mutex_;
  std::condition_variable cond_;
};

class HcclNodeExecutor : public NodeExecutor {
 public:
  Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const;


inc/framework/common/types.h (+2, -0)

@@ -441,6 +441,8 @@ REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");
REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
REGISTER_OPTYPE_DECLARE(HCOMALLTOALLV, "HcomAllToAllV");
REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV");

REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");
REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp");


tests/depends/hccl/src/hccl_stub.cc (+11, -0)

@@ -42,3 +42,14 @@ HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_pt
                               HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) {
  return HCCL_SUCCESS;
}

HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback) {
  return HCCL_SUCCESS;
}

HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
                                          std::function<void(HcclResult status)> callback) {
  return HCCL_SUCCESS;
}
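
These stubs return HCCL_SUCCESS unconditionally so the unit tests below can dlopen libhccl_stub.so and have both enqueue symbols resolve without real HCCL hardware. A hedged sketch of that resolution check, assuming the stub library exports the names unmangled, as the executor's dlsym calls expect:

#include <dlfcn.h>

// Open the stub library (path taken from the unit test below; assumed
// build layout) and verify the enqueue symbol resolves.
bool StubExportsEnqueue() {
  void *handle = dlopen("../build/tests/depends/hccl/libhccl_stub.so", RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    return false;
  }
  bool found = dlsym(handle, "HcomExecEnqueueAllToAllV") != nullptr;
  dlclose(handle);
  return found;
}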



tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc (+240, -108)

@@ -1,108 +1,240 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <vector>

#define private public
#define protected public
#include "graph/runtime_inference_context.h"
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/hccl/hccl_node_executor.h"
#undef protected
#undef private

using namespace std;
using namespace testing;

namespace {
const string kHcclSoPath = "../build/tests/depends/hccl/libhccl_stub.so";
}

namespace ge {
using namespace hybrid;

class UtestHcclNodeExecutor : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  op_desc->SetStreamId(0);
  static int32_t index = 0;
  op_desc->SetId(index++);

  GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
  TensorUtils::SetSize(tensor, 64);
  vector<int64_t> input_offset;
  for (int i = 0; i < in_num; i++) {
    op_desc->AddInputDesc(tensor);
    input_offset.emplace_back(i * 64);
  }
  op_desc->SetInputOffset(input_offset);

  vector<int64_t> output_offset;
  for (int i = 0; i < out_num; i++) {
    op_desc->AddOutputDesc(tensor);
    output_offset.emplace_back(in_num * 64 + i * 64);
  }
  op_desc->SetOutputOffset(output_offset);

  return graph.AddNode(op_desc);
}

TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0);
  std::unique_ptr<NodeItem> new_node;
  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
  NodeItem *node_item = new_node.get();
  node_item->input_start = 0;
  node_item->output_start = 0;

  GraphItem graph_item;
  GraphExecutionContext graph_context;
  SubgraphContext subgraph_context(&graph_item, &graph_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);

  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
  ASSERT_NE(node_state, nullptr);

  RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id));
  RuntimeInferenceContext *ctx = nullptr;
  RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx);

  Shape s({1, 3});
  TensorDesc tensor_desc(s);
  Tensor tensor(tensor_desc);
  std::vector<uint8_t> data = {1, 2, 3, 4};
  tensor.SetData(data);
  ctx->SetTensor(1, 0, tensor.Clone());

  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
  vector<HcomRemoteAccessAddrInfo> addr_infos;
  shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>();
  task->remote_index_ = {1, 0};
  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);

  Shape s2({1});
  TensorDesc tensor_desc2(s2);
  Tensor tensor2(tensor_desc2);
  ctx->SetTensor(1, 0, tensor2.Clone());
  task->ExtractTensor(*unique_task_context, addr_infos);
  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
  RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
}

TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeModelPtr ge_sub_model = std::make_shared<GeModel>();
  GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
  HybridModel hybrid_model(ge_root_model);

  NodePtr node = CreateNode(*graph, "gatheralltoallv", HCOMGATHERALLTOALLV, 4, 2);
  std::unique_ptr<NodeItem> new_node;
  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
  NodeItem *node_item = new_node.get();
  hybrid_model.node_items_[node] = std::move(new_node);
  node_item->input_start = 0;
  node_item->output_start = 0;

  GraphItem graph_item;
  graph_item.node_items_.emplace_back(node_item);
  graph_item.total_inputs_ = 4;
  graph_item.total_outputs_ = 2;

  GraphExecutionContext graph_context;
  SubgraphContext subgraph_context(&graph_item, &graph_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
  ASSERT_NE(node_state, nullptr);
  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
  ASSERT_NE(unique_task_context, nullptr);
  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
  node_state->SetTaskContext(shared_task_context);

  for (int i = 0; i < 4; ++i) {
    uint64_t value_0 = 512;
    TensorValue in_tensor0(&value_0, sizeof(value_0));
    subgraph_context.SetInput(*node_item, i, in_tensor0);  // set each of the four inputs
  }
  uint64_t value_0 = 512;
  TensorValue out_tensor0(&value_0, sizeof(value_0));
  subgraph_context.SetOutput(*node_item, 0, out_tensor0);
  uint64_t value_1 = 512;
  TensorValue out_tensor1(&value_1, sizeof(value_1));
  subgraph_context.SetOutput(*node_item, 1, out_tensor1);

  NodeTaskPtr task = nullptr;
  HcclNodeExecutor node_executor;
  ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
  ASSERT_NE(task, nullptr);

  auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
  ASSERT_NE(handle, nullptr);
  node_state->GetTaskContext()->handle_ = handle;
  std::function<void()> done = []() {};
  ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
  if (handle != nullptr) {
    dlclose(handle);
  }
}

TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeModelPtr ge_sub_model = std::make_shared<GeModel>();
  GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
  HybridModel hybrid_model(ge_root_model);

  NodePtr node = CreateNode(*graph, "alltoallv", HCOMALLTOALLV, 5, 1);
  std::unique_ptr<NodeItem> new_node;
  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
  NodeItem *node_item = new_node.get();
  hybrid_model.node_items_[node] = std::move(new_node);
  node_item->input_start = 0;
  node_item->output_start = 0;

  GraphItem graph_item;
  graph_item.node_items_.emplace_back(node_item);
  graph_item.total_inputs_ = 5;
  graph_item.total_outputs_ = 1;

  GraphExecutionContext graph_context;
  SubgraphContext subgraph_context(&graph_item, &graph_context);
  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
  graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
  ASSERT_NE(node_state, nullptr);
  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
  ASSERT_NE(unique_task_context, nullptr);
  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
  node_state->SetTaskContext(shared_task_context);

  for (int i = 0; i < 5; ++i) {
    uint64_t value_0 = 512;
    TensorValue in_tensor0(&value_0, sizeof(value_0));
    subgraph_context.SetInput(*node_item, i, in_tensor0);  // set each of the five inputs
  }
  uint64_t value_1 = 512;
  TensorValue out_tensor0(&value_1, sizeof(value_1));
  subgraph_context.SetOutput(*node_item, 0, out_tensor0);

  NodeTaskPtr task = nullptr;
  HcclNodeExecutor node_executor;
  ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
  ASSERT_NE(task, nullptr);

  auto handle = dlopen(kHcclSoPath.c_str(), RTLD_NOW | RTLD_GLOBAL);
  ASSERT_NE(handle, nullptr);
  node_state->GetTaskContext()->handle_ = handle;
  std::function<void()> done = []() {};
  ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
  if (handle != nullptr) {
    dlclose(handle);
  }
}
} // namespace ge

third_party/fwkacllib/inc/hccl/base.h (+24, -0)

@@ -123,6 +123,30 @@ struct HcomRemoteAccessAddrInfo {
  u64 length;  // Memory Length in Bytes
};

struct HcomAllToAllVParams {
  void *sendbuf;
  void *sendcounts;
  void *sdispls;
  HcclDataType sendtype;
  void *recvbuf;
  void *recvcounts;
  void *rdispls;
  HcclDataType recvtype;
  const char *group;
};

struct HcomGatherAllToAllVParams {
  void *addrInfo;
  void *addrInfoCountPerRank;
  void *recvbuf;
  void *recvcounts;
  void *rdispls;
  void *gatheredbuf;
  s32 addrLength;
  HcclDataType recvtype;
  const char *group;
};

#ifdef __cplusplus
}
#endif // __cplusplus
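
For orientation, a hedged sketch of filling HcomAllToAllVParams by hand for a two-rank exchange; host pointers stand in for the device tensors the executor actually binds, and HCCL_DATA_TYPE_FP32 and the null group are assumed values:

// Illustrative only: the hybrid executor binds these fields from
// TaskContext inputs/outputs rather than host arrays.
HcomAllToAllVParams MakeDemoParams(float *sendbuf, float *recvbuf,
                                   void *counts, void *displs) {
  HcomAllToAllVParams params{};
  params.sendbuf = sendbuf;               // data to distribute across ranks
  params.sendcounts = counts;             // per-rank element counts
  params.sdispls = displs;                // per-rank offsets into sendbuf
  params.sendtype = HCCL_DATA_TYPE_FP32;  // assumed HcclDataType member
  params.recvbuf = recvbuf;
  params.recvcounts = counts;
  params.rdispls = displs;
  params.recvtype = HCCL_DATA_TYPE_FP32;
  params.group = nullptr;                 // assumed: default communication group
  return params;
}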


third_party/fwkacllib/inc/hccl/hcom.h (+5, -0)

@@ -164,6 +164,11 @@ HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType,
                                       const std::vector<HcomRemoteAccessAddrInfo>& addrInfos,
                                       std::function<void(HcclResult status)> callback);

HcclResult HcomExecEnqueueAllToAllV(HcomAllToAllVParams params, std::function<void(HcclResult status)> callback);

HcclResult HcomExecEnqueueGatherAllToAllV(HcomGatherAllToAllVParams params,
                                          std::function<void(HcclResult status)> callback);
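
And a matching hedged sketch of the enqueue-plus-callback contract declared above, mirroring the lambda AllToAllNodeTask passes in and assuming the library invokes the callback once the collective finishes:

#include <cstdio>
#include <functional>

// Enqueue an AllToAllV and report completion status from the callback.
HcclResult EnqueueDemo(HcomAllToAllVParams params) {
  auto callback = [](HcclResult status) {
    if (status != HCCL_SUCCESS) {
      std::fprintf(stderr, "AllToAllV failed with status %d\n", static_cast<int>(status));
    }
  };
  return HcomExecEnqueueAllToAllV(params, callback);
}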

/**
* @brief Register memories and init resources for remote access.
*

