Browse Source

single op support null tensor

tags/v1.5.1
wangzhengjun 4 years ago
parent
commit
87fd8b270b
4 changed files with 72 additions and 6 deletions
  1. +1
    -2
      ge/graph/manager/util/hcom_util.cc
  2. +12
    -4
      ge/single_op/task/op_task.cc
  3. +11
    -0
      tests/ut/ge/graph/manager/hcom_util_unittest.cc
  4. +48
    -0
      tests/ut/ge/single_op/single_op_task_unittest.cc

+ 1
- 2
ge/graph/manager/util/hcom_util.cc View File

@@ -109,8 +109,7 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size),
"[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i);
// dynamic shape hccl op get size from output tensor desc
if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) {
GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i));
if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE) && (op_desc->GetOutputDescPtr(i) != nullptr)) {
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size),
"[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i);
}


+ 12
- 4
ge/single_op/task/op_task.cc View File

@@ -746,16 +746,24 @@ Status AiCpuBaseTask::UpdateIoAddr(const vector<DataBuffer> &inputs, const vecto
GE_CHK_BOOL_RET_STATUS(non_const_index < inputs.size(), ACL_ERROR_GE_PARAM_INVALID,
"[Check][Size] Input size is %zu, but get non_const_index is %zu", inputs.size(), non_const_index);
auto addr = inputs[non_const_index].data;
GE_CHECK_NOTNULL(addr);
GELOGD("AICpuTask input[%zu] addr = %p", input_index, addr);
uint64_t length = inputs[non_const_index].length;
if (length != 0 && addr == nullptr) {
GELOGE(PARAM_INVALID, "[Check][Addr]AiCpuTask input[%zu] addr is nullptr, length = %lu", input_index, length);
return PARAM_INVALID;
}
GELOGD("AICpuTask input[%zu] addr = %p, length = %lu.", input_index, addr, length);
*arg_base++ = reinterpret_cast<uintptr_t>(addr);
non_const_index++;
}

for (size_t i = 0; i < outputs.size(); ++i) {
auto addr = outputs[i].data;
GE_CHECK_NOTNULL(addr);
GELOGD("AICpuTask output[%zu] addr = %p", i, addr);
uint64_t length = outputs[i].length;
if (length != 0 && addr == nullptr) {
GELOGE(PARAM_INVALID, "[Check][Addr]AiCpuTask output[%zu] addr is nullptr, length = %lu", i, length);
return PARAM_INVALID;
}
GELOGD("AICpuTask output[%zu] addr = %p, length = %lu.", i, addr, length);
*arg_base++ = reinterpret_cast<uintptr_t>(addr);
}



+ 11
- 0
tests/ut/ge/graph/manager/hcom_util_unittest.cc View File

@@ -94,4 +94,15 @@ TEST_F(UtestHcomUtil, test_GetHcomCount_succ) {
auto ret = hcom_ome_util.GetHcomCount(op_desc, HCCL_DATA_TYPE_FP32, true, count);
EXPECT_EQ(ret, 0);
}

// Builds an HCOMSEND node that only has an input descriptor of shape {1, 1, 224, 224}
// (this test adds no output descriptor) and checks that GetHcomCount succeeds and
// reports 224 * 224 for HCCL_DATA_TYPE_FP32.
TEST_F(UtestHcomUtil, test_GetHcomCount_succ_2) {
  ComputeGraphPtr compute_graph = std::make_shared<ComputeGraph>("test");
  NodePtr send_node = NodeBuilder("node", HCOMSEND).AddInputDesc({1, 1, 224, 224}).Build(compute_graph);

  HcomOmeUtil util_under_test;
  int hcom_count = 0;
  auto status = util_under_test.GetHcomCount(send_node->GetOpDesc(), HCCL_DATA_TYPE_FP32, true, hcom_count);

  EXPECT_EQ(status, SUCCESS);
  EXPECT_EQ(hcom_count, 224 * 224);
}
} // namespace ge

+ 48
- 0
tests/ut/ge/single_op/single_op_task_unittest.cc View File

@@ -189,3 +189,51 @@ TEST_F(UtestSingleOpTask, test_atomic_exec) {
optiling::utils::OpRunInfo run_info(0, true, 0);
task.CalcTilingInfo(run_info);
}

// Exercises AiCpuCCTask::UpdateIoAddr with a task that has two inputs (the first
// marked const, the second non-const) and one output, so callers supply exactly
// one input buffer and one output buffer.
//
// Scenarios covered:
//   1. default-constructed buffers are accepted and leave every slot null;
//   2. valid buffers are written to slots 1 (input) and 2 (output), slot 0 is untouched;
//   3. a null input address with a length of 4 is rejected with PARAM_INVALID;
//   4. a null output address with a length of 4 is rejected with PARAM_INVALID.
TEST_F(UtestSingleOpTask, test_aicpu_task_update_io_addr) {
  AiCpuCCTask task;
  task.num_inputs_ = 2;
  task.num_outputs_ = 1;
  task.input_is_const_ = {true, false};
  // The array bound must be a constant expression: a runtime `int` bound would make
  // this a variable-length array, which is a compiler extension and cannot portably
  // take an initializer.
  constexpr int kTotalAddrNum = 3;
  uint32_t *addrs[kTotalAddrNum] = {nullptr, nullptr, nullptr};
  task.io_addr_ = reinterpret_cast<uintptr_t *>(addrs);
  task.io_addr_num_ = kTotalAddrNum;

  {
    // Default-constructed buffers (presumably null data with zero length -- confirm
    // against DataBuffer's default constructor) must be accepted.
    vector<DataBuffer> inputs(1, DataBuffer());
    vector<DataBuffer> outputs(1, DataBuffer());
    auto ret = task.UpdateIoAddr(inputs, outputs);
    ASSERT_EQ(ret, SUCCESS);
    ASSERT_EQ(addrs[0], nullptr);
    ASSERT_EQ(addrs[1], nullptr);
    ASSERT_EQ(addrs[2], nullptr);
  }

  {
    // Valid buffers: the single supplied input lands in slot 1 and the output in
    // slot 2, while slot 0 (the const input) keeps its previous value.
    uint32_t data_buf[2];
    vector<DataBuffer> inputs{DataBuffer(&data_buf[0], 4, false)};
    vector<DataBuffer> outputs{DataBuffer(&data_buf[1], 4, false)};
    auto ret = task.UpdateIoAddr(inputs, outputs);
    ASSERT_EQ(ret, SUCCESS);
    ASSERT_EQ(addrs[0], nullptr);
    ASSERT_EQ(addrs[1], &data_buf[0]);
    ASSERT_EQ(addrs[2], &data_buf[1]);
  }

  {
    // Null input address combined with the value 4 (assumed to be the length
    // argument) is invalid.
    uint32_t data_buf[2];
    vector<DataBuffer> inputs{DataBuffer(nullptr, 4, false)};
    vector<DataBuffer> outputs{DataBuffer(&data_buf[1], 4, false)};
    auto ret = task.UpdateIoAddr(inputs, outputs);
    ASSERT_EQ(ret, PARAM_INVALID);
  }

  {
    // Null output address combined with the value 4 (assumed to be the length
    // argument) is invalid.
    uint32_t data_buf[2];
    vector<DataBuffer> inputs{DataBuffer(&data_buf[0], 4, false)};
    vector<DataBuffer> outputs{DataBuffer(nullptr, 4, false)};
    auto ret = task.UpdateIoAddr(inputs, outputs);
    ASSERT_EQ(ret, PARAM_INVALID);
  }
}

Loading…
Cancel
Save