@@ -20,7 +20,6 @@
 #include "graph/attr_value.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/util/hcom_util.h"
-#include "graph/runtime_inference_context.h"
 #include "graph/utils/type_utils.h"
 #include "graph/types.h"
 #include "hccl/hcom.h"
@@ -177,61 +176,8 @@ Status RdmaNodeTask::Init(TaskContext &context) {
   return SUCCESS;
 }
 
-Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
-  RuntimeInferenceContext *ctx = nullptr;
-  GE_CHK_STATUS_RET(
-      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));
-
-  ge::Tensor remote_tensor;
-  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
-  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
-  if (data == nullptr) {
-    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
-      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
-      return SUCCESS;
-    } else {
-      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
-                         context.GetNodeItem().NodeType().c_str());
-      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
-             context.GetNodeItem().NodeType().c_str());
-      return FAILED;
-    }
-  }
-
-  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
-  if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
-    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
-                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
-                       dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
-                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
-    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
-           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
-           dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
-           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
-    return PARAM_INVALID;
-  }
-
-  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
-    size_t remote_size = 0;
-    for (auto idx = 0; idx < dims.front(); ++idx) {
-      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
-      auto line_idx = idx * kVarTableRowCnt;
-      remote_size += data[line_idx + kVarTableIdxLen];
-    }
-    auto allocator = NpuMemoryAllocator::GetAllocator();
-    GE_CHECK_NOTNULL(allocator);
-    AllocationAttr attr;
-    attr.SetMemType(RDMA_HBM);
-    for (auto i = 0; i < context.NumOutputs(); ++i) {
-      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
-      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
-      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
-    }
-  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
-    AllocationAttr attr;
-    attr.SetMemType(RDMA_HBM);
-    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
-  }
-
+Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
+                                 vector<HcomRemoteAccessAddrInfo> &addr_infos) {
   TensorValue *tv;
   if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
     tv = context.MutableOutput(local_index_);
@@ -239,7 +185,6 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
     tv = context.MutableInput(local_index_);
   }
   GE_CHECK_NOTNULL(tv);
-  auto row_num = dims.front();
   addr_infos.resize(row_num);
   if (skip_flag_) {
     int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
@@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
   return SUCCESS;
 }
 
+Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) {
+  RuntimeInferenceContext *ctx = nullptr;
+  GE_CHK_STATUS_RET(
+      RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx));
+
+  ge::Tensor remote_tensor;
+  GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
+  auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
+  if (data == nullptr) {
+    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
+      return SUCCESS;
+    } else {
+      REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
+                         context.GetNodeItem().NodeType().c_str());
+      GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s",
+             context.GetNodeItem().NodeType().c_str());
+      return FAILED;
+    }
+  }
+
+  auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
+  if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
+    REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu"
+                       "and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
+                       dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
+                       context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
+    GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed,"
+           "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)",
+           dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt,
+           context.GetNodeName(), context.GetNodeItem().NodeType().c_str());
+    return PARAM_INVALID;
+  }
+
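+  // HCOMREMOTEREAD: sum the length column of every variable-table row and allocate an RDMA_HBM buffer of that
+  // total size for each output.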
+  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
+    size_t remote_size = 0;
+    for (auto idx = 0; idx < dims.front(); ++idx) {
+      FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
+      auto line_idx = idx * kVarTableRowCnt;
+      remote_size += data[line_idx + kVarTableIdxLen];
+    }
+    auto allocator = NpuMemoryAllocator::GetAllocator();
+    GE_CHECK_NOTNULL(allocator);
+    AllocationAttr attr;
+    attr.SetMemType(RDMA_HBM);
+    for (auto i = 0; i < context.NumOutputs(); ++i) {
+      GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size);
+      auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
+      GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
+    }
+  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
+    AllocationAttr attr;
+    attr.SetMemType(RDMA_HBM);
+    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
+  }
+
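+  // Per-row remote and local addresses are resolved by SetAddrInfo.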
+  auto row_num = dims.front();
+  return SetAddrInfo(context, ctx, data, row_num, addr_infos);
+}
+
 Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
   GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName());
   auto HcomExecEnqueueRemoteAccess =
@@ -18,6 +18,7 @@
 #define HYBRID_HCCL_NODE_EXECUTOR_H_
 #include "common/opskernel/ge_task_info.h"
 #include "graph/op_desc.h"
+#include "graph/runtime_inference_context.h"
 #include "hybrid/model/hybrid_model.h"
 #include "hybrid/node_executor/node_executor.h"
@@ -53,6 +54,8 @@ class RdmaNodeTask : public NodeTask {
   Status Init(TaskContext &context) override;
 
  private:
+  Status SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num,
+                     vector<HcomRemoteAccessAddrInfo> &addr_infos);
   Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos);
   std::pair<int64_t, int64_t> remote_index_;
   std::pair<int64_t, int64_t> offset_index_;
@@ -710,6 +710,7 @@ set(PASS_TEST_FILES
     "graph/passes/infershape_pass_unittest.cc"
     "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc"
    "graph/passes/multi_batch_clone_pass_unittest.cc"
+    "graph/passes/subgraph_const_migration_pass_unittest.cc"
     "graph/passes/replace_with_empty_const_pass_unittest.cc"
     "graph/passes/link_gen_mask_nodes_pass_unittest.cc"
     "graph/passes/transpose_transdata_pass_unittest.cc"
@@ -718,7 +719,7 @@ set(PASS_TEST_FILES
     "graph/passes/mark_node_unknown_shape_pass_unittest.cc"
     "graph/passes/reshape_recovery_pass_unittest.cc"
     "graph/passes/cast_remove_pass_unittest.cc"
-    "graph/passes/memcpy_addr_async_unittest.cc"
+    "graph/passes/memcpy_addr_async_unittest.cc"
     "graph/passes/hccl_continuous_pass_unittest.cc"
     "graph/passes/hccl_memcpy_pass_unittest.cc"
@@ -843,6 +844,7 @@ set(HYBRID_TEST_FILES
     "hybrid/model/hybrid_model_builder_unittest.cc"
     "hybrid/node_executor/rts/rts_node_task_unittest.cc"
     "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc"
+    "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc"
     "hybrid/executor/hybrid_model_async_executor_unittest.cc"
     "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc"
     "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc"
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <set>
+#include <string>
+
+#include "framework/omg/omg_inner_types.h"
+#include "graph/common/local_context.h"
+#include "graph/passes/subgraph_const_migration_pass.h"
+#include "inc/pass_manager.h"
+#include "register/op_registry.h"
+
+namespace ge {
+class UtestSubgraphConstMigrationPass : public testing::Test {
+ protected:
+  void SetUp() {}
+  void TearDown() {}
+
+ public:
+  NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
+    GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
+    auto op_desc = std::make_shared<OpDesc>(name, type);
+    for (auto i = 0; i < in_num; ++i) {
+      op_desc->AddInputDesc(test_desc);
+    }
+    for (auto i = 0; i < out_num; ++i) {
+      op_desc->AddOutputDesc(test_desc);
+    }
+    if (type == "Const") {
+      uint64_t const_value = 101;
+      auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t));
+      AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight);
+    }
+    return graph->AddNode(op_desc);
+  }
+
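+  // Builds the Case branch body: a Data input and two Const nodes, each gated by a control edge from another
+  // Data node, all feeding a Conv2D.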
+  void make_original_graph(const ComputeGraphPtr &graph) {
+    auto data = MakeNode(graph, 1, 1, "data", "Data");
+    {
+      AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
+      AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
+    }
+
+    auto const1 = MakeNode(graph, 0, 1, "const1", "Const");
+    {
+      auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
+      AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
+      AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);
+      GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor());
+    }
+
+    auto const2 = MakeNode(graph, 0, 1, "const2", "Const");
+    {
+      auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
+      AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
+      AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3);
+      GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor());
+    }
+
+    auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
+    GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
+    GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
+    GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
+  }
+
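+  // Builds the root graph: a Case node with four Data inputs and a NetOutput, plus two cloned
+  // "_ascend_mbatch_batch_" instances of the branch above registered as its subgraphs.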
+  void make_multibatch_graph(const ComputeGraphPtr &graph) {
+    auto index = MakeNode(graph, 1, 1, "index", "Data");
+    auto data = MakeNode(graph, 1, 1, "data", "Data");
+    auto data1 = MakeNode(graph, 1, 1, "data1", "Data");
+    auto data2 = MakeNode(graph, 1, 1, "data2", "Data");
+    AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0);
+    AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1);
+    AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2);
+
+    auto case1 = MakeNode(graph, 4, 1, "case", "Case");
+    GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0));
+    GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1));
+    GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2));
+    GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3));
+    auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput");
+    GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+
+    AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
+    case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic);
+    ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch");
+    make_original_graph(branch);
+
+    for (int i = 0; i < 2; ++i) {
+      std::string name("_ascend_mbatch_batch_" + std::to_string(i));
+      std::vector<NodePtr> input_nodes;
+      std::vector<NodePtr> output_nodes;
+      ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes);
+
+      subgraph->SetName(name);
+      subgraph->SetParentNode(case1);
+      subgraph->SetParentGraph(graph);
+      graph->AddSubgraph(subgraph->GetName(), subgraph);
+
+      case1->GetOpDesc()->AddSubgraphName(name);
+      case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName());
+    }
+  }
+};
+
+TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) {
+  PassManager pass_manager;
+  pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass);
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  make_multibatch_graph(graph);
+  pass_manager.Run(graph);
+}
+}  // namespace ge
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <vector>
+
+#define private public
+#define protected public
+#include "graph/runtime_inference_context.h"
+#include "hybrid/executor/subgraph_context.h"
+#include "hybrid/node_executor/hccl/hccl_node_executor.h"
+#undef protected
+#undef private
+
+using namespace std;
+using namespace testing;
+
+namespace ge {
+using namespace hybrid;
+
+class UtestHcclNodeExecutor : public testing::Test {
+ protected:
+  void SetUp() {}
+  void TearDown() {}
+};
+
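+// Creates a node whose input and output descriptors are 64-byte ND/INT64 tensors with consecutive data offsets.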
+static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
+  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
+  op_desc->SetStreamId(0);
+  static int32_t index = 0;
+  op_desc->SetId(index++);
+
+  GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
+  TensorUtils::SetSize(tensor, 64);
+  vector<int64_t> input_offset;
+  for (int i = 0; i < in_num; i++) {
+    op_desc->AddInputDesc(tensor);
+    input_offset.emplace_back(i * 64);
+  }
+  op_desc->SetInputOffset(input_offset);
+
+  vector<int64_t> output_offset;
+  for (int i = 0; i < out_num; i++) {
+    op_desc->AddOutputDesc(tensor);
+    output_offset.emplace_back(in_num * 64 + i * 64);
+  }
+  op_desc->SetOutputOffset(output_offset);
+
+  return graph.AddNode(op_desc);
+}
+
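+// Exercises RdmaNodeTask::ExtractTensor through a minimal TaskContext; both remote tensors set here are
+// expected to be rejected with PARAM_INVALID.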
+TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
+  NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0);
+  std::unique_ptr<NodeItem> new_node;
+  ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
+  NodeItem *node_item = new_node.get();
+  node_item->input_start = 0;
+  node_item->output_start = 0;
+
+  GraphItem graph_item;
+  GraphExecutionContext graph_context;
+  SubgraphContext subgraph_context(&graph_item, &graph_context);
+  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
+
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
+  ASSERT_NE(node_state, nullptr);
+
+  RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id));
+  RuntimeInferenceContext *ctx = nullptr;
+  RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx);
+
+  Shape s({1, 3});
+  TensorDesc tensor_desc(s);
+  Tensor tensor(tensor_desc);
+  std::vector<uint8_t> data = {1, 2, 3, 4};
+  tensor.SetData(data);
+  ctx->SetTensor(1, 0, tensor.Clone());
+
+  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
+  vector<HcomRemoteAccessAddrInfo> addr_infos;
+  shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>();
+  task->remote_index_ = {1, 0};
+  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
+
+  Shape s2({1});
+  TensorDesc tensor_desc2(s2);
+  Tensor tensor2(tensor_desc2);
+  ctx->SetTensor(1, 0, tensor2.Clone());
+  task->ExtractTensor(*unique_task_context, addr_infos);
+  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
+
+  RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
+}
+}  // namespace ge