Browse Source

support int4

tags/v1.3.0
李磊 3 years ago
parent
commit
2e80593540
6 changed files with 97 additions and 18 deletions
  1. +19
    -15
      ge/ge_local_engine/engine/host_cpu_engine.cc
  2. +1
    -1
      metadef
  3. +1
    -1
      parser
  4. +2
    -1
      tests/ut/common/graph/CMakeLists.txt
  5. +1
    -0
      tests/ut/ge/CMakeLists.txt
  6. +73
    -0
      tests/ut/ge/common/host_cpu_engine_unittest.cc

+ 19
- 15
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -26,12 +26,15 @@
#include "common/math/math_util.h" #include "common/math/math_util.h"


namespace { namespace {
#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \
#define CREATE_OUTPUT_CASE(DTYPE) \
case (DTYPE): { \ case (DTYPE): { \
GeTensorPtr ge_tensor = nullptr; \ GeTensorPtr ge_tensor = nullptr; \
if (need_create_flag) { \ if (need_create_flag) { \
uint64_t size = data_num * sizeof(TYPE); \
ge_tensor = MakeShared<GeTensor>(out_desc, size); \
int64_t size = ge::GetSizeInBytes(static_cast<int64_t>(data_num), DTYPE); \
if (size < 0) { \
return INTERNAL_ERROR; \
} \
ge_tensor = MakeShared<GeTensor>(out_desc, static_cast<size_t>(size)); \
GE_CHECK_NOTNULL(ge_tensor); \ GE_CHECK_NOTNULL(ge_tensor); \
GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \ GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \
ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \
@@ -180,18 +183,19 @@ Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc,
} }
} }
switch (out_desc.GetDataType()) { switch (out_desc.GetDataType()) {
CREATE_OUTPUT_CASE(DT_BOOL, bool)
CREATE_OUTPUT_CASE(DT_INT8, int8_t)
CREATE_OUTPUT_CASE(DT_INT16, int16_t)
CREATE_OUTPUT_CASE(DT_INT32, int32_t)
CREATE_OUTPUT_CASE(DT_INT64, int64_t)
CREATE_OUTPUT_CASE(DT_UINT8, uint8_t)
CREATE_OUTPUT_CASE(DT_UINT16, uint16_t)
CREATE_OUTPUT_CASE(DT_UINT32, uint32_t)
CREATE_OUTPUT_CASE(DT_UINT64, uint64_t)
CREATE_OUTPUT_CASE(DT_FLOAT16, fp16_t)
CREATE_OUTPUT_CASE(DT_FLOAT, float)
CREATE_OUTPUT_CASE(DT_DOUBLE, double)
CREATE_OUTPUT_CASE(DT_BOOL)
CREATE_OUTPUT_CASE(DT_INT8)
CREATE_OUTPUT_CASE(DT_INT16)
CREATE_OUTPUT_CASE(DT_INT32)
CREATE_OUTPUT_CASE(DT_INT64)
CREATE_OUTPUT_CASE(DT_UINT8)
CREATE_OUTPUT_CASE(DT_UINT16)
CREATE_OUTPUT_CASE(DT_UINT32)
CREATE_OUTPUT_CASE(DT_UINT64)
CREATE_OUTPUT_CASE(DT_FLOAT16)
CREATE_OUTPUT_CASE(DT_FLOAT)
CREATE_OUTPUT_CASE(DT_DOUBLE)
CREATE_OUTPUT_CASE(DT_INT4)
default: default:
GELOGW("data type %s not support.", GELOGW("data type %s not support.",
TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str());


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 23718da69af64f8a57051ee64d5515ae1e103c70
Subproject commit 567381faaff179106abafb264ba696f45c4d2b43

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 9bb03f21773f028b07d5a912db6f176268962c7d
Subproject commit 9226f9532a3884490b03e48df5d7aa02611e21f4

+ 2
- 1
tests/ut/common/graph/CMakeLists.txt View File

@@ -82,6 +82,7 @@ set(SRC_FILES
"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" "${GE_CODE_DIR}/metadef/graph/operator_factory.cc"
"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc"
"${GE_CODE_DIR}/metadef/graph/tensor.cc" "${GE_CODE_DIR}/metadef/graph/tensor.cc"
"${GE_CODE_DIR}/metadef/graph/types.cc"
"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc"
"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc"
"${GE_CODE_DIR}/metadef/graph/format_refiner.cc" "${GE_CODE_DIR}/metadef/graph/format_refiner.cc"
@@ -115,7 +116,7 @@ target_compile_definitions(ut_libgraph PRIVATE
google=ascend_private google=ascend_private
) )


target_link_libraries(ut_libgraph
target_link_libraries(ut_libgraph
$<BUILD_INTERFACE:intf_pub> $<BUILD_INTERFACE:intf_pub>
gtest gtest
gtest_main gtest_main


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -802,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES
"session/ge_api_unittest.cc" "session/ge_api_unittest.cc"
"session/inner_session_unittest.cc" "session/inner_session_unittest.cc"
"session/session_manager_unittest.cc" "session/session_manager_unittest.cc"
"common/host_cpu_engine_unittest.cc"
) )


set(GENERATOR_TEST_FILES set(GENERATOR_TEST_FILES


+ 73
- 0
tests/ut/ge/common/host_cpu_engine_unittest.cc View File

@@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#define protected public
#define private public
#include "ge_local_engine/engine/host_cpu_engine.h"
#undef private
#undef protected

namespace ge {
class UTEST_host_cpu_engine : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};

TEST_F(UTEST_host_cpu_engine, PrepareOutputs_success) {
OpDescPtr op_desc = std::make_shared<OpDesc>("name", "type");
op_desc->AddOutputDesc("1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_BOOL));
op_desc->AddOutputDesc("2", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT8));
op_desc->AddOutputDesc("3", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT16));
op_desc->AddOutputDesc("4", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32));
op_desc->AddOutputDesc("5", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT64));
op_desc->AddOutputDesc("6", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT8));
op_desc->AddOutputDesc("7", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT16));
op_desc->AddOutputDesc("8", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT32));
op_desc->AddOutputDesc("9", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT64));
op_desc->AddOutputDesc("10", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT16));
op_desc->AddOutputDesc("11", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT));
op_desc->AddOutputDesc("12", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_DOUBLE));
op_desc->AddOutputDesc("13", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT4));

vector<GeTensorPtr> outputs;
GeTensorPtr value = std::make_shared<GeTensor>();
for (int32_t i = 0; i < 13; i++) {
outputs.push_back(value);
}

map<std::string, Tensor> named_outputs;
auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(named_outputs.size(), 13);
}

TEST_F(UTEST_host_cpu_engine, PrepareOutputs_need_create_success) {
OpDescPtr op_desc = std::make_shared<OpDesc>("name", "type");
op_desc->AddOutputDesc("output_1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32));

vector<GeTensorPtr> outputs;
map<std::string, Tensor> named_outputs;
auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(named_outputs.size(), 1);
EXPECT_EQ(named_outputs["output_1"].GetSize(), 16);
EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetDataType(), DT_INT32);
EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetShape().GetShapeSize(), 4);
}
} // namespace ge

Loading…
Cancel
Save