From 35bc98c540eb075ba32e191004fafe4e5fb5f95e Mon Sep 17 00:00:00 2001 From: "lianghao24@hisilicon.com" Date: Fri, 5 Feb 2021 17:38:22 +0800 Subject: [PATCH] hcomrecieve --- ge/graph/manager/util/hcom_util.cc | 15 ++- tests/ut/ge/CMakeLists.txt | 2 + .../ut/ge/graph/manager/hcom_util_unittest.cc | 97 +++++++++++++++++++ 3 files changed, 106 insertions(+), 8 deletions(-) create mode 100644 tests/ut/ge/graph/manager/hcom_util_unittest.cc diff --git a/ge/graph/manager/util/hcom_util.cc b/ge/graph/manager/util/hcom_util.cc index 50fa9936..53dd9410 100644 --- a/ge/graph/manager/util/hcom_util.cc +++ b/ge/graph/manager/util/hcom_util.cc @@ -84,15 +84,14 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType int32_t size = 0; GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!"); if (op_desc->GetType() == HCOMRECEIVE) { - vector shape_dims; - bool ret = ge::AttrUtils::GetListInt(op_desc, HCOM_ATTR_SHAPE, shape_dims); - if (ret == false) { - GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: shape."); - return PARAM_INVALID; + for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { + int64_t output_size = 0; + GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); + GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), output_size), + "Get size from TensorDesc failed, op: %s, output index: %zu.", op_desc->GetName().c_str(), i); + output_size = (output_size + align_size - 1) / align_size * align_size; + total_size += output_size; } - ge::GeShape shape = ge::GeShape(shape_dims); - int64_t input_size = shape.GetShapeSize() * size; - total_size = (input_size + align_size - 1) / align_size * align_size; } else { for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { int64_t input_size = 0; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index ba1bfaac..7c49c0a7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -353,6 +353,7 @@ set(COMMON_FORMAT_SRC_FILES "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" + "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" ) set(GRAPH_OPTIMIZE_COMMON_SRC_FILES @@ -750,6 +751,7 @@ set(MULTI_PARTS_TEST_FILES "graph/build/logical_stream_allocator_unittest.cc" "graph/build/mem_assigner_unittest.cc" "graph/preprocess/graph_preprocess_unittest.cc" + "graph/manager/hcom_util_unittest.cc" "session/omg_omg_unittest.cc" ) diff --git a/tests/ut/ge/graph/manager/hcom_util_unittest.cc b/tests/ut/ge/graph/manager/hcom_util_unittest.cc new file mode 100644 index 00000000..9f104f5f --- /dev/null +++ b/tests/ut/ge/graph/manager/hcom_util_unittest.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "common/ge_inner_error_codes.h" +#include "common/types.h" +#include "common/util.h" +#include "graph/utils/attr_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/passes/addn_pass.h" + +#define private public +#define protected public +#include "graph/manager/util/hcom_util.h" +#include "ge/ge_api.h" +#undef private +#undef protected + +using namespace std; + +namespace ge { +namespace { +GeTensorDescPtr CreateTensorDesc(std::initializer_list shape, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + GeShape ge_shape{vector(shape)}; + GeTensorDescPtr tensor_desc = std::make_shared(); + tensor_desc->SetShape(ge_shape); + tensor_desc->SetFormat(format); + tensor_desc->SetDataType(data_type); + return tensor_desc; +} + +class NodeBuilder { + public: + NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared(name, type); } + + NodeBuilder &AddInputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { + op_desc_->AddOutputDesc(tensor_desc->Clone()); + return *this; + } + + NodePtr Build(const ComputeGraphPtr &graph) { + NodePtr node = graph->AddNode(op_desc_); + return node; + } + + private: + OpDescPtr op_desc_; +}; +} // namespace + +class UtestHcomUtil : public testing::Test { + protected: + void SetUp() { + } + void TearDown() { + } +}; + +TEST_F(UtestHcomUtil, test_GetHcomCount_succ) { + ComputeGraphPtr graph = std::make_shared("test"); + NodePtr node = NodeBuilder("node", HCOMRECEIVE).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph); + auto op_desc = node->GetOpDesc(); + + HcomOmeUtil hcom_ome_util; + int count = 0; + auto ret = hcom_ome_util.GetHcomCount(op_desc, HCCL_DATA_TYPE_FP32, true, count); + EXPECT_EQ(ret, 0); +} +} // namespace ge \ No newline at end of file