|
|
@@ -117,7 +117,8 @@ Status ModifyTensorDesc(GeTensorDesc &tensor) { |
|
|
|
int64_t storage_format_val = static_cast<Format>(FORMAT_RESERVED); |
|
|
|
(void)AttrUtils::GetInt(tensor, ge::ATTR_NAME_STORAGE_FORMAT, storage_format_val); |
|
|
|
auto storage_format = static_cast<Format>(storage_format_val); |
|
|
|
if (storage_format != FORMAT_RESERVED) { |
|
|
|
auto format = tensor.GetFormat(); |
|
|
|
if (storage_format != FORMAT_RESERVED && storage_format != format) { |
|
|
|
std::vector<int64_t> storage_shape; |
|
|
|
if (!AttrUtils::GetListInt(tensor, ge::ATTR_NAME_STORAGE_SHAPE, storage_shape)) { |
|
|
|
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Get][storage_shape]failed while storage_format was set."); |
|
|
@@ -127,8 +128,10 @@ Status ModifyTensorDesc(GeTensorDesc &tensor) { |
|
|
|
|
|
|
|
GELOGD("Storage format set. update shape to [%s], and original shape to [%s]", |
|
|
|
GeShape(storage_shape).ToString().c_str(), tensor.GetShape().ToString().c_str()); |
|
|
|
tensor.SetShape(GeShape(std::move(storage_shape))); |
|
|
|
tensor.SetFormat(std::move(storage_format)); |
|
|
|
tensor.SetOriginShape(tensor.GetShape()); |
|
|
|
tensor.SetOriginFormat(format); |
|
|
|
tensor.SetShape(GeShape(storage_shape)); |
|
|
|
tensor.SetFormat(storage_format); |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|