Browse Source

!1094 hcomreceive

From: @dimitri_rose
Reviewed-by: @wqtshg,@ji_chen
Signed-off-by: @ji_chen
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
b281a4e0e8
3 changed files with 106 additions and 8 deletions
  1. +7
    -8
      ge/graph/manager/util/hcom_util.cc
  2. +2
    -0
      tests/ut/ge/CMakeLists.txt
  3. +97
    -0
      tests/ut/ge/graph/manager/hcom_util_unittest.cc

+ 7
- 8
ge/graph/manager/util/hcom_util.cc View File

@@ -84,15 +84,14 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType
int32_t size = 0;
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!");
if (op_desc->GetType() == HCOMRECEIVE) {
vector<int64_t> shape_dims;
bool ret = ge::AttrUtils::GetListInt(op_desc, HCOM_ATTR_SHAPE, shape_dims);
if (ret == false) {
GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: shape.");
return PARAM_INVALID;
for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
int64_t output_size = 0;
GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i));
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), output_size),
"Get size from TensorDesc failed, op: %s, output index: %zu.", op_desc->GetName().c_str(), i);
output_size = (output_size + align_size - 1) / align_size * align_size;
total_size += output_size;
}
ge::GeShape shape = ge::GeShape(shape_dims);
int64_t input_size = shape.GetShapeSize() * size;
total_size = (input_size + align_size - 1) / align_size * align_size;
} else {
for (size_t i = 0; i < op_desc->GetInputsSize(); i++) {
int64_t input_size = 0;


+ 2
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -353,6 +353,7 @@ set(COMMON_FORMAT_SRC_FILES
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc"
"${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc"
"${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc"
"${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc"
)

set(GRAPH_OPTIMIZE_COMMON_SRC_FILES
@@ -750,6 +751,7 @@ set(MULTI_PARTS_TEST_FILES
"graph/build/logical_stream_allocator_unittest.cc"
"graph/build/mem_assigner_unittest.cc"
"graph/preprocess/graph_preprocess_unittest.cc"
"graph/manager/hcom_util_unittest.cc"
"session/omg_omg_unittest.cc"
)



+ 97
- 0
tests/ut/ge/graph/manager/hcom_util_unittest.cc View File

@@ -0,0 +1,97 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <memory>

#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "common/util.h"
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/passes/addn_pass.h"

#define private public
#define protected public
#include "graph/manager/util/hcom_util.h"
#include "ge/ge_api.h"
#undef private
#undef protected

using namespace std;

namespace ge {
namespace {
GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
GeShape ge_shape{vector<int64_t>(shape)};
GeTensorDescPtr tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(ge_shape);
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);
return tensor_desc;
}

class NodeBuilder {
public:
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }

NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}

NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}

NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) {
op_desc_->AddOutputDesc(tensor_desc->Clone());
return *this;
}

NodePtr Build(const ComputeGraphPtr &graph) {
NodePtr node = graph->AddNode(op_desc_);
return node;
}

private:
OpDescPtr op_desc_;
};
} // namespace

class UtestHcomUtil : public testing::Test {
protected:
void SetUp() {
}
void TearDown() {
}
};

TEST_F(UtestHcomUtil, test_GetHcomCount_succ) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
NodePtr node = NodeBuilder("node", HCOMRECEIVE).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
auto op_desc = node->GetOpDesc();

HcomOmeUtil hcom_ome_util;
int count = 0;
auto ret = hcom_ome_util.GetHcomCount(op_desc, HCCL_DATA_TYPE_FP32, true, count);
EXPECT_EQ(ret, 0);
}
} // namespace ge

Loading…
Cancel
Save