| @@ -20,7 +20,6 @@ | |||||
| #include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/manager/util/hcom_util.h" | #include "graph/manager/util/hcom_util.h" | ||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
| @@ -177,61 +176,8 @@ Status RdmaNodeTask::Init(TaskContext &context) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| GE_CHK_STATUS_RET( | |||||
| RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||||
| ge::Tensor remote_tensor; | |||||
| GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | |||||
| auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); | |||||
| if (data == nullptr) { | |||||
| if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||||
| GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); | |||||
| if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { | |||||
| REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" | |||||
| "and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," | |||||
| "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { | |||||
| size_t remote_size = 0; | |||||
| for (auto idx = 0; idx < dims.front(); ++idx) { | |||||
| FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | |||||
| auto line_idx = idx * kVarTableRowCnt; | |||||
| remote_size += data[line_idx + kVarTableIdxLen]; | |||||
| } | |||||
| auto allocator = NpuMemoryAllocator::GetAllocator(); | |||||
| GE_CHECK_NOTNULL(allocator); | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| for (auto i = 0; i < context.NumOutputs(); ++i) { | |||||
| GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); | |||||
| auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); | |||||
| GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); | |||||
| } | |||||
| } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) | |||||
| } | |||||
| Status RdmaNodeTask::SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, | |||||
| vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||||
| TensorValue *tv; | TensorValue *tv; | ||||
| if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) { | if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) { | ||||
| tv = context.MutableOutput(local_index_); | tv = context.MutableOutput(local_index_); | ||||
| @@ -239,7 +185,6 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
| tv = context.MutableInput(local_index_); | tv = context.MutableInput(local_index_); | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(tv); | GE_CHECK_NOTNULL(tv); | ||||
| auto row_num = dims.front(); | |||||
| addr_infos.resize(row_num); | addr_infos.resize(row_num); | ||||
| if (skip_flag_) { | if (skip_flag_) { | ||||
| int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset"); | int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset"); | ||||
| @@ -294,6 +239,65 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| GE_CHK_STATUS_RET( | |||||
| RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||||
| ge::Tensor remote_tensor; | |||||
| GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | |||||
| auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData()); | |||||
| if (data == nullptr) { | |||||
| if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) { | |||||
| GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } else { | |||||
| REPORT_INNER_ERROR("E19999", "Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(FAILED, "[Find][NodeType]Tensor data is nullptr. and kRdmaScatterTypes not contain %s", | |||||
| context.GetNodeItem().NodeType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims(); | |||||
| if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) { | |||||
| REPORT_INNER_ERROR("E19999", "Variable table shape check failed, number of shape dims:%zu not equal expect:%zu" | |||||
| "and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param]Variable table shape check failed," | |||||
| "number of shape dims:%zu not equal expect:%zu and shape dims back:%zu not equal expect:%zu, node:%s(%s)", | |||||
| dims.size(), kVarTableDims, dims.back(), kVarTableRowCnt, | |||||
| context.GetNodeName(), context.GetNodeItem().NodeType().c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) { | |||||
| size_t remote_size = 0; | |||||
| for (auto idx = 0; idx < dims.front(); ++idx) { | |||||
| FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | |||||
| auto line_idx = idx * kVarTableRowCnt; | |||||
| remote_size += data[line_idx + kVarTableIdxLen]; | |||||
| } | |||||
| auto allocator = NpuMemoryAllocator::GetAllocator(); | |||||
| GE_CHECK_NOTNULL(allocator); | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| for (auto i = 0; i < context.NumOutputs(); ++i) { | |||||
| GELOGD("Allocate rdma memory for node %s, size: %zu", context.GetNodeName(), remote_size); | |||||
| auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr); | |||||
| GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release())))); | |||||
| } | |||||
| } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) { | |||||
| AllocationAttr attr; | |||||
| attr.SetMemType(RDMA_HBM); | |||||
| GE_CHK_STATUS_RET(context.AllocateOutputs(&attr)) | |||||
| } | |||||
| auto row_num = dims.front(); | |||||
| return SetAddrInfo(context, ctx, data, row_num, addr_infos); | |||||
| } | |||||
| Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
| GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName()); | GELOGI("[%s] RdmaNodeTask::ExecuteAsync in.", context.GetNodeName()); | ||||
| auto HcomExecEnqueueRemoteAccess = | auto HcomExecEnqueueRemoteAccess = | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define HYBRID_HCCL_NODE_EXECUTOR_H_ | #define HYBRID_HCCL_NODE_EXECUTOR_H_ | ||||
| #include "common/opskernel/ge_task_info.h" | #include "common/opskernel/ge_task_info.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| @@ -53,6 +54,8 @@ class RdmaNodeTask : public NodeTask { | |||||
| Status Init(TaskContext &context) override; | Status Init(TaskContext &context) override; | ||||
| private: | private: | ||||
| Status SetAddrInfo(TaskContext &context, RuntimeInferenceContext *ctx, uint64_t *data, int64_t row_num, | |||||
| vector<HcomRemoteAccessAddrInfo> &addr_infos); | |||||
| Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos); | ||||
| std::pair<int64_t, int64_t> remote_index_; | std::pair<int64_t, int64_t> remote_index_; | ||||
| std::pair<int64_t, int64_t> offset_index_; | std::pair<int64_t, int64_t> offset_index_; | ||||
| @@ -710,6 +710,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
| "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | "graph/passes/mark_force_unknown_for_cond_pass_unittest.cc" | ||||
| "graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
| "graph/passes/subgraph_const_migration_pass_unittest.cc" | |||||
| "graph/passes/replace_with_empty_const_pass_unittest.cc" | "graph/passes/replace_with_empty_const_pass_unittest.cc" | ||||
| "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | "graph/passes/link_gen_mask_nodes_pass_unittest.cc" | ||||
| "graph/passes/transpose_transdata_pass_unittest.cc" | "graph/passes/transpose_transdata_pass_unittest.cc" | ||||
| @@ -718,7 +719,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
| "graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| "graph/passes/hccl_continuous_pass_unittest.cc" | "graph/passes/hccl_continuous_pass_unittest.cc" | ||||
| "graph/passes/hccl_memcpy_pass_unittest.cc" | "graph/passes/hccl_memcpy_pass_unittest.cc" | ||||
| @@ -843,6 +844,7 @@ set(HYBRID_TEST_FILES | |||||
| "hybrid/model/hybrid_model_builder_unittest.cc" | "hybrid/model/hybrid_model_builder_unittest.cc" | ||||
| "hybrid/node_executor/rts/rts_node_task_unittest.cc" | "hybrid/node_executor/rts/rts_node_task_unittest.cc" | ||||
| "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | ||||
| "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc" | |||||
| "hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
| "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
| "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
| @@ -0,0 +1,125 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/passes/subgraph_const_migration_pass.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| class UtestSubgraphConstMigrationPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| public: | |||||
| NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
| GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (auto i = 0; i < in_num; ++i) { | |||||
| op_desc->AddInputDesc(test_desc); | |||||
| } | |||||
| for (auto i = 0; i < out_num; ++i) { | |||||
| op_desc->AddOutputDesc(test_desc); | |||||
| } | |||||
| if (type == "Const") { | |||||
| uint64_t const_value = 101; | |||||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
| AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); | |||||
| } | |||||
| return graph->AddNode(op_desc); | |||||
| } | |||||
| void make_original_graph(const ComputeGraphPtr &graph) { | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| { | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
| } | |||||
| auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); | |||||
| { | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||||
| GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); | |||||
| } | |||||
| auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); | |||||
| { | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); | |||||
| GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); | |||||
| } | |||||
| auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
| } | |||||
| void make_multibatch_graph(const ComputeGraphPtr &graph) { | |||||
| auto index = MakeNode(graph, 1, 1, "index", "Data"); | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| auto case1 = MakeNode(graph, 4, 1, "case", "Case"); | |||||
| GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); | |||||
| auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); | |||||
| GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
| AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); | |||||
| case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); | |||||
| ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch"); | |||||
| make_original_graph(branch); | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| std::string name("_ascend_mbatch_batch_" + std::to_string(i)); | |||||
| std::vector<NodePtr> input_nodes; | |||||
| std::vector<NodePtr> output_nodes; | |||||
| ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); | |||||
| subgraph->SetName(name); | |||||
| subgraph->SetParentNode(case1); | |||||
| subgraph->SetParentGraph(graph); | |||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | |||||
| case1->GetOpDesc()->AddSubgraphName(name); | |||||
| case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); | |||||
| } | |||||
| } | |||||
| }; | |||||
| TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) { | |||||
| PassManager pass_manager; | |||||
| pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| make_multibatch_graph(graph); | |||||
| pass_manager.Run(graph); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gmock/gmock.h> | |||||
| #include <gtest/gtest.h> | |||||
| #include <vector> | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "graph/runtime_inference_context.h" | |||||
| #include "hybrid/executor/subgraph_context.h" | |||||
| #include "hybrid/node_executor/hccl/hccl_node_executor.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| using namespace hybrid; | |||||
| class UtestHcclNodeExecutor : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
| TensorUtils::SetSize(tensor, 64); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(i * 64); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(in_num * 64 + i * 64); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| NodePtr node = CreateNode(*graph, "hcom", HCOMREMOTEREAD, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| RuntimeInferenceContext::CreateContext(std::to_string(graph_context.context_id)); | |||||
| RuntimeInferenceContext *ctx = nullptr; | |||||
| RuntimeInferenceContext::GetContext(std::to_string(graph_context.context_id), &ctx); | |||||
| Shape s({1, 3}); | |||||
| TensorDesc tensor_desc(s); | |||||
| Tensor tensor(tensor_desc); | |||||
| std::vector<uint8_t> data = {1, 2, 3, 4}; | |||||
| tensor.SetData(data); | |||||
| ctx->SetTensor(1, 0, tensor.Clone()); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| vector<HcomRemoteAccessAddrInfo> addr_infos; | |||||
| shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||||
| task->remote_index_ = {1, 0}; | |||||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
| Shape s2({1}); | |||||
| TensorDesc tensor_desc2(s2); | |||||
| Tensor tensor2(tensor_desc2); | |||||
| ctx->SetTensor(1, 0, tensor2.Clone()); | |||||
| task->ExtractTensor(*unique_task_context, addr_infos); | |||||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||||
| } | |||||
| } // namespace ge | |||||