|
|
@@ -113,6 +113,30 @@ Status UpdateInputsBufferAddr(StreamResource *stream_resource, rtStream_t stream |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status ModifyTensorDesc(GeTensorDesc &tensor) { |
|
|
|
int64_t storage_format_val = static_cast<Format>(FORMAT_RESERVED); |
|
|
|
(void)AttrUtils::GetInt(tensor, ge::ATTR_NAME_STORAGE_FORMAT, storage_format_val); |
|
|
|
auto storage_format = static_cast<Format>(storage_format_val); |
|
|
|
auto format = tensor.GetFormat(); |
|
|
|
if (storage_format != FORMAT_RESERVED && storage_format != format) { |
|
|
|
std::vector<int64_t> storage_shape; |
|
|
|
if (!AttrUtils::GetListInt(tensor, ge::ATTR_NAME_STORAGE_SHAPE, storage_shape)) { |
|
|
|
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Get][storage_shape]failed while storage_format was set."); |
|
|
|
REPORT_INNER_ERROR("E19999", "Get storage_shape failed while storage_format was set."); |
|
|
|
return ACL_ERROR_GE_INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
GELOGD("Storage format set. update shape to [%s], and original shape to [%s]", |
|
|
|
GeShape(storage_shape).ToString().c_str(), tensor.GetShape().ToString().c_str()); |
|
|
|
tensor.SetOriginShape(tensor.GetShape()); |
|
|
|
tensor.SetOriginFormat(format); |
|
|
|
tensor.SetShape(GeShape(storage_shape)); |
|
|
|
tensor.SetFormat(storage_format); |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status InitHybridModelArgs(const std::vector<DataBuffer> &input_buffers, |
|
|
|
const std::vector<DataBuffer> &output_buffers, |
|
|
|
const std::vector<GeTensorDesc> &inputs_desc, |
|
|
@@ -126,6 +150,7 @@ Status InitHybridModelArgs(const std::vector<DataBuffer> &input_buffers, |
|
|
|
for (auto &tensor_desc : inputs_desc) { |
|
|
|
auto desc = MakeShared<GeTensorDesc>(tensor_desc); |
|
|
|
GE_CHECK_NOTNULL(desc); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(ModifyTensorDesc(*desc)); |
|
|
|
args.input_desc.emplace_back(desc); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|