@@ -26,12 +26,15 @@ | |||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
namespace { | namespace { | ||||
#define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ | |||||
#define CREATE_OUTPUT_CASE(DTYPE) \ | |||||
case (DTYPE): { \ | case (DTYPE): { \ | ||||
GeTensorPtr ge_tensor = nullptr; \ | GeTensorPtr ge_tensor = nullptr; \ | ||||
if (need_create_flag) { \ | if (need_create_flag) { \ | ||||
uint64_t size = data_num * sizeof(TYPE); \ | |||||
ge_tensor = MakeShared<GeTensor>(out_desc, size); \ | |||||
int64_t size = ge::GetSizeInBytes(static_cast<int64_t>(data_num), DTYPE); \ | |||||
if (size < 0) { \ | |||||
return INTERNAL_ERROR; \ | |||||
} \ | |||||
ge_tensor = MakeShared<GeTensor>(out_desc, static_cast<size_t>(size)); \ | |||||
GE_CHECK_NOTNULL(ge_tensor); \ | GE_CHECK_NOTNULL(ge_tensor); \ | ||||
GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \ | GELOGD("node:%s allocate output %zu success, size=%ld", op_desc->GetName().c_str(), i, size); \ | ||||
ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ | ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ | ||||
@@ -180,18 +183,19 @@ Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc, | |||||
} | } | ||||
} | } | ||||
switch (out_desc.GetDataType()) { | switch (out_desc.GetDataType()) { | ||||
CREATE_OUTPUT_CASE(DT_BOOL, bool) | |||||
CREATE_OUTPUT_CASE(DT_INT8, int8_t) | |||||
CREATE_OUTPUT_CASE(DT_INT16, int16_t) | |||||
CREATE_OUTPUT_CASE(DT_INT32, int32_t) | |||||
CREATE_OUTPUT_CASE(DT_INT64, int64_t) | |||||
CREATE_OUTPUT_CASE(DT_UINT8, uint8_t) | |||||
CREATE_OUTPUT_CASE(DT_UINT16, uint16_t) | |||||
CREATE_OUTPUT_CASE(DT_UINT32, uint32_t) | |||||
CREATE_OUTPUT_CASE(DT_UINT64, uint64_t) | |||||
CREATE_OUTPUT_CASE(DT_FLOAT16, fp16_t) | |||||
CREATE_OUTPUT_CASE(DT_FLOAT, float) | |||||
CREATE_OUTPUT_CASE(DT_DOUBLE, double) | |||||
CREATE_OUTPUT_CASE(DT_BOOL) | |||||
CREATE_OUTPUT_CASE(DT_INT8) | |||||
CREATE_OUTPUT_CASE(DT_INT16) | |||||
CREATE_OUTPUT_CASE(DT_INT32) | |||||
CREATE_OUTPUT_CASE(DT_INT64) | |||||
CREATE_OUTPUT_CASE(DT_UINT8) | |||||
CREATE_OUTPUT_CASE(DT_UINT16) | |||||
CREATE_OUTPUT_CASE(DT_UINT32) | |||||
CREATE_OUTPUT_CASE(DT_UINT64) | |||||
CREATE_OUTPUT_CASE(DT_FLOAT16) | |||||
CREATE_OUTPUT_CASE(DT_FLOAT) | |||||
CREATE_OUTPUT_CASE(DT_DOUBLE) | |||||
CREATE_OUTPUT_CASE(DT_INT4) | |||||
default: | default: | ||||
GELOGW("data type %s not support.", | GELOGW("data type %s not support.", | ||||
TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | ||||
@@ -1 +1 @@ | |||||
Subproject commit 23718da69af64f8a57051ee64d5515ae1e103c70 | |||||
Subproject commit 567381faaff179106abafb264ba696f45c4d2b43 |
@@ -1 +1 @@ | |||||
Subproject commit 9bb03f21773f028b07d5a912db6f176268962c7d | |||||
Subproject commit 9226f9532a3884490b03e48df5d7aa02611e21f4 |
@@ -82,6 +82,7 @@ set(SRC_FILES | |||||
"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | "${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | "${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/tensor.cc" | "${GE_CODE_DIR}/metadef/graph/tensor.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/types.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | "${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | "${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | ||||
@@ -115,7 +116,7 @@ target_compile_definitions(ut_libgraph PRIVATE | |||||
google=ascend_private | google=ascend_private | ||||
) | ) | ||||
target_link_libraries(ut_libgraph | |||||
target_link_libraries(ut_libgraph | |||||
$<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
gtest | gtest | ||||
gtest_main | gtest_main | ||||
@@ -802,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
"session/ge_api_unittest.cc" | "session/ge_api_unittest.cc" | ||||
"session/inner_session_unittest.cc" | "session/inner_session_unittest.cc" | ||||
"session/session_manager_unittest.cc" | "session/session_manager_unittest.cc" | ||||
"common/host_cpu_engine_unittest.cc" | |||||
) | ) | ||||
set(GENERATOR_TEST_FILES | set(GENERATOR_TEST_FILES | ||||
@@ -0,0 +1,73 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||||
#undef private | |||||
#undef protected | |||||
namespace ge { | |||||
class UTEST_host_cpu_engine : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
TEST_F(UTEST_host_cpu_engine, PrepareOutputs_success) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>("name", "type"); | |||||
op_desc->AddOutputDesc("1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_BOOL)); | |||||
op_desc->AddOutputDesc("2", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT8)); | |||||
op_desc->AddOutputDesc("3", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT16)); | |||||
op_desc->AddOutputDesc("4", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32)); | |||||
op_desc->AddOutputDesc("5", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT64)); | |||||
op_desc->AddOutputDesc("6", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT8)); | |||||
op_desc->AddOutputDesc("7", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT16)); | |||||
op_desc->AddOutputDesc("8", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT32)); | |||||
op_desc->AddOutputDesc("9", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_UINT64)); | |||||
op_desc->AddOutputDesc("10", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT16)); | |||||
op_desc->AddOutputDesc("11", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_FLOAT)); | |||||
op_desc->AddOutputDesc("12", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_DOUBLE)); | |||||
op_desc->AddOutputDesc("13", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT4)); | |||||
vector<GeTensorPtr> outputs; | |||||
GeTensorPtr value = std::make_shared<GeTensor>(); | |||||
for (int32_t i = 0; i < 13; i++) { | |||||
outputs.push_back(value); | |||||
} | |||||
map<std::string, Tensor> named_outputs; | |||||
auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
EXPECT_EQ(named_outputs.size(), 13); | |||||
} | |||||
TEST_F(UTEST_host_cpu_engine, PrepareOutputs_need_create_success) { | |||||
OpDescPtr op_desc = std::make_shared<OpDesc>("name", "type"); | |||||
op_desc->AddOutputDesc("output_1", GeTensorDesc(GeShape({2, 2}), FORMAT_NCHW, DT_INT32)); | |||||
vector<GeTensorPtr> outputs; | |||||
map<std::string, Tensor> named_outputs; | |||||
auto ret = HostCpuEngine::GetInstance().PrepareOutputs(op_desc, outputs, named_outputs); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
EXPECT_EQ(named_outputs.size(), 1); | |||||
EXPECT_EQ(named_outputs["output_1"].GetSize(), 16); | |||||
EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetDataType(), DT_INT32); | |||||
EXPECT_EQ(named_outputs["output_1"].GetTensorDesc().GetShape().GetShapeSize(), 4); | |||||
} | |||||
} // namespace ge |