From 2e805935408db15cf9866507dde20821d96388ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=A3=8A?= Date: Wed, 19 May 2021 11:38:52 +0800 Subject: [PATCH] support int4 --- ge/ge_local_engine/engine/host_cpu_engine.cc | 34 +++++---- metadef | 2 +- parser | 2 +- tests/ut/common/graph/CMakeLists.txt | 3 +- tests/ut/ge/CMakeLists.txt | 1 + .../ut/ge/common/host_cpu_engine_unittest.cc | 73 +++++++++++++++++++ 6 files changed, 97 insertions(+), 18 deletions(-) create mode 100644 tests/ut/ge/common/host_cpu_engine_unittest.cc diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc index 8f01a166..f3a14317 100755 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -26,12 +26,15 @@ #include "common/math/math_util.h" namespace { -#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ +#define CREATE_OUTPUT_CASE(DTYPE) \ case (DTYPE): { \ GeTensorPtr ge_tensor = nullptr; \ if (need_create_flag) { \ - uint64_t size = data_num * sizeof(TYPE); \ - ge_tensor = MakeShared(out_desc, size); \ + int64_t size = ge::GetSizeInBytes(static_cast(data_num), DTYPE); \ + if (size < 0) { \ + return INTERNAL_ERROR; \ + } \ + ge_tensor = MakeShared(out_desc, static_cast(size)); \ GE_CHECK_NOTNULL(ge_tensor); \ GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \ ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ @@ -180,18 +183,19 @@ Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc, } } switch (out_desc.GetDataType()) { - CREATE_OUTPUT_CASE(DT_BOOL, bool) - CREATE_OUTPUT_CASE(DT_INT8, int8_t) - CREATE_OUTPUT_CASE(DT_INT16, int16_t) - CREATE_OUTPUT_CASE(DT_INT32, int32_t) - CREATE_OUTPUT_CASE(DT_INT64, int64_t) - CREATE_OUTPUT_CASE(DT_UINT8, uint8_t) - CREATE_OUTPUT_CASE(DT_UINT16, uint16_t) - CREATE_OUTPUT_CASE(DT_UINT32, uint32_t) - CREATE_OUTPUT_CASE(DT_UINT64, uint64_t) - CREATE_OUTPUT_CASE(DT_FLOAT16, fp16_t) - CREATE_OUTPUT_CASE(DT_FLOAT, float) - CREATE_OUTPUT_CASE(DT_DOUBLE, double) + CREATE_OUTPUT_CASE(DT_BOOL) + CREATE_OUTPUT_CASE(DT_INT8) + CREATE_OUTPUT_CASE(DT_INT16) + CREATE_OUTPUT_CASE(DT_INT32) + CREATE_OUTPUT_CASE(DT_INT64) + CREATE_OUTPUT_CASE(DT_UINT8) + CREATE_OUTPUT_CASE(DT_UINT16) + CREATE_OUTPUT_CASE(DT_UINT32) + CREATE_OUTPUT_CASE(DT_UINT64) + CREATE_OUTPUT_CASE(DT_FLOAT16) + CREATE_OUTPUT_CASE(DT_FLOAT) + CREATE_OUTPUT_CASE(DT_DOUBLE) + CREATE_OUTPUT_CASE(DT_INT4) default: GELOGW("data type %s not support.", TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); diff --git a/metadef b/metadef index 23718da6..567381fa 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 23718da69af64f8a57051ee64d5515ae1e103c70 +Subproject commit 567381faaff179106abafb264ba696f45c4d2b43 diff --git a/parser b/parser index 9bb03f21..9226f953 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 9bb03f21773f028b07d5a912db6f176268962c7d +Subproject commit 9226f9532a3884490b03e48df5d7aa02611e21f4 diff --git a/tests/ut/common/graph/CMakeLists.txt b/tests/ut/common/graph/CMakeLists.txt index a957298a..73780967 100644 --- a/tests/ut/common/graph/CMakeLists.txt +++ b/tests/ut/common/graph/CMakeLists.txt @@ -82,6 +82,7 @@ set(SRC_FILES "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" "${GE_CODE_DIR}/metadef/graph/tensor.cc" + "${GE_CODE_DIR}/metadef/graph/types.cc" "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" @@ -115,7 +116,7 @@ target_compile_definitions(ut_libgraph PRIVATE google=ascend_private ) -target_link_libraries(ut_libgraph +target_link_libraries(ut_libgraph $ gtest gtest_main diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 16f3672b..37824c4a 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -802,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES "session/ge_api_unittest.cc" "session/inner_session_unittest.cc" "session/session_manager_unittest.cc" + "common/host_cpu_engine_unittest.cc" ) set(GENERATOR_TEST_FILES diff --git a/tests/ut/ge/common/host_cpu_engine_unittest.cc b/tests/ut/ge/common/host_cpu_engine_unittest.cc new file mode 100644 index 00000000..1a414c86 --- /dev/null +++ b/tests/ut/ge/common/host_cpu_engine_unittest.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "ge_local_engine/engine/host_cpu_engine.h" +#undef private +#undef protected + +namespace ge { +class UTEST_host_cpu_engine : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UTEST_host_cpu_engine, PrepareOutputs_success) { + OpDescPtr op_desc = std::make_shared("name", "type"); + op_desc->AddOutputDesc("1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_BOOL)); + op_desc->AddOutputDesc("2", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT8)); + op_desc->AddOutputDesc("3", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT16)); + op_desc->AddOutputDesc("4", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32)); + op_desc->AddOutputDesc("5", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT64)); + op_desc->AddOutputDesc("6", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT8)); + op_desc->AddOutputDesc("7", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT16)); + op_desc->AddOutputDesc("8", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT32)); + op_desc->AddOutputDesc("9", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT64)); + op_desc->AddOutputDesc("10", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT16)); + op_desc->AddOutputDesc("11", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT)); + op_desc->AddOutputDesc("12", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_DOUBLE)); + op_desc->AddOutputDesc("13", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT4)); + + vector outputs; + GeTensorPtr value = std::make_shared(); + for (int32_t i = 0; i < 13; i++) { + outputs.push_back(value); + } + + map named_outputs; + auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(named_outputs.size(), 13); +} + +TEST_F(UTEST_host_cpu_engine, PrepareOutputs_need_create_success) { + OpDescPtr op_desc = std::make_shared("name", "type"); + op_desc->AddOutputDesc("output_1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32)); + + vector outputs; + map named_outputs; + auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(named_outputs.size(), 1); + EXPECT_EQ(named_outputs["output_1"].GetSize(), 16); + EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetDataType(), DT_INT32); + EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetShape().GetShapeSize(), 4); +} +} // namespace ge \ No newline at end of file