| @@ -26,12 +26,15 @@ | |||||
| #include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
| namespace { | namespace { | ||||
| #define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ | |||||
| #define CREATE_OUTPUT_CASE(DTYPE) \ | |||||
| case (DTYPE): { \ | case (DTYPE): { \ | ||||
| GeTensorPtr ge_tensor = nullptr; \ | GeTensorPtr ge_tensor = nullptr; \ | ||||
| if (need_create_flag) { \ | if (need_create_flag) { \ | ||||
| uint64_t size = data_num * sizeof(TYPE); \ | |||||
| ge_tensor = MakeShared<GeTensor>(out_desc, size); \ | |||||
| int64_t size = ge::GetSizeInBytes(static_cast<int64_t>(data_num), DTYPE); \ | |||||
| if (size < 0) { \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } \ | |||||
| ge_tensor = MakeShared<GeTensor>(out_desc, static_cast<size_t>(size)); \ | |||||
| GE_CHECK_NOTNULL(ge_tensor); \ | GE_CHECK_NOTNULL(ge_tensor); \ | ||||
| GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \ | GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \ | ||||
| ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ | ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ | ||||
| @@ -180,18 +183,19 @@ Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc, | |||||
| } | } | ||||
| } | } | ||||
| switch (out_desc.GetDataType()) { | switch (out_desc.GetDataType()) { | ||||
| CREATE_OUTPUT_CASE(DT_BOOL, bool) | |||||
| CREATE_OUTPUT_CASE(DT_INT8, int8_t) | |||||
| CREATE_OUTPUT_CASE(DT_INT16, int16_t) | |||||
| CREATE_OUTPUT_CASE(DT_INT32, int32_t) | |||||
| CREATE_OUTPUT_CASE(DT_INT64, int64_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT8, uint8_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT16, uint16_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT32, uint32_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT64, uint64_t) | |||||
| CREATE_OUTPUT_CASE(DT_FLOAT16, fp16_t) | |||||
| CREATE_OUTPUT_CASE(DT_FLOAT, float) | |||||
| CREATE_OUTPUT_CASE(DT_DOUBLE, double) | |||||
| CREATE_OUTPUT_CASE(DT_BOOL) | |||||
| CREATE_OUTPUT_CASE(DT_INT8) | |||||
| CREATE_OUTPUT_CASE(DT_INT16) | |||||
| CREATE_OUTPUT_CASE(DT_INT32) | |||||
| CREATE_OUTPUT_CASE(DT_INT64) | |||||
| CREATE_OUTPUT_CASE(DT_UINT8) | |||||
| CREATE_OUTPUT_CASE(DT_UINT16) | |||||
| CREATE_OUTPUT_CASE(DT_UINT32) | |||||
| CREATE_OUTPUT_CASE(DT_UINT64) | |||||
| CREATE_OUTPUT_CASE(DT_FLOAT16) | |||||
| CREATE_OUTPUT_CASE(DT_FLOAT) | |||||
| CREATE_OUTPUT_CASE(DT_DOUBLE) | |||||
| CREATE_OUTPUT_CASE(DT_INT4) | |||||
| default: | default: | ||||
| GELOGW("data type %s not support.", | GELOGW("data type %s not support.", | ||||
| TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit 23718da69af64f8a57051ee64d5515ae1e103c70 | |||||
| Subproject commit 567381faaff179106abafb264ba696f45c4d2b43 | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit 9bb03f21773f028b07d5a912db6f176268962c7d | |||||
| Subproject commit 9226f9532a3884490b03e48df5d7aa02611e21f4 | |||||
| @@ -82,6 +82,7 @@ set(SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/tensor.cc" | "${GE_CODE_DIR}/metadef/graph/tensor.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/types.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | ||||
| @@ -115,7 +116,7 @@ target_compile_definitions(ut_libgraph PRIVATE | |||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_link_libraries(ut_libgraph | |||||
| target_link_libraries(ut_libgraph | |||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| gtest | gtest | ||||
| gtest_main | gtest_main | ||||
| @@ -802,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "session/ge_api_unittest.cc" | "session/ge_api_unittest.cc" | ||||
| "session/inner_session_unittest.cc" | "session/inner_session_unittest.cc" | ||||
| "session/session_manager_unittest.cc" | "session/session_manager_unittest.cc" | ||||
| "common/host_cpu_engine_unittest.cc" | |||||
| ) | ) | ||||
| set(GENERATOR_TEST_FILES | set(GENERATOR_TEST_FILES | ||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "ge_local_engine/engine/host_cpu_engine.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace ge { | |||||
| class UTEST_host_cpu_engine : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(UTEST_host_cpu_engine, PrepareOutputs_success) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>("name", "type"); | |||||
| op_desc->AddOutputDesc("1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_BOOL)); | |||||
| op_desc->AddOutputDesc("2", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT8)); | |||||
| op_desc->AddOutputDesc("3", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT16)); | |||||
| op_desc->AddOutputDesc("4", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32)); | |||||
| op_desc->AddOutputDesc("5", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT64)); | |||||
| op_desc->AddOutputDesc("6", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT8)); | |||||
| op_desc->AddOutputDesc("7", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT16)); | |||||
| op_desc->AddOutputDesc("8", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT32)); | |||||
| op_desc->AddOutputDesc("9", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT64)); | |||||
| op_desc->AddOutputDesc("10", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT16)); | |||||
| op_desc->AddOutputDesc("11", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT)); | |||||
| op_desc->AddOutputDesc("12", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_DOUBLE)); | |||||
| op_desc->AddOutputDesc("13", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT4)); | |||||
| vector<GeTensorPtr> outputs; | |||||
| GeTensorPtr value = std::make_shared<GeTensor>(); | |||||
| for (int32_t i = 0; i < 13; i++) { | |||||
| outputs.push_back(value); | |||||
| } | |||||
| map<std::string, Tensor> named_outputs; | |||||
| auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(named_outputs.size(), 13); | |||||
| } | |||||
| TEST_F(UTEST_host_cpu_engine, PrepareOutputs_need_create_success) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>("name", "type"); | |||||
| op_desc->AddOutputDesc("output_1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32)); | |||||
| vector<GeTensorPtr> outputs; | |||||
| map<std::string, Tensor> named_outputs; | |||||
| auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(named_outputs.size(), 1); | |||||
| EXPECT_EQ(named_outputs["output_1"].GetSize(), 16); | |||||
| EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetDataType(), DT_INT32); | |||||
| EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetShape().GetShapeSize(), 4); | |||||
| } | |||||
| } // namespace ge | |||||