| @@ -117,7 +117,8 @@ Status ModifyTensorDesc(GeTensorDesc &tensor) { | |||||
| int64_t storage_format_val = static_cast<Format>(FORMAT_RESERVED); | int64_t storage_format_val = static_cast<Format>(FORMAT_RESERVED); | ||||
| (void)AttrUtils::GetInt(tensor, ge::ATTR_NAME_STORAGE_FORMAT, storage_format_val); | (void)AttrUtils::GetInt(tensor, ge::ATTR_NAME_STORAGE_FORMAT, storage_format_val); | ||||
| auto storage_format = static_cast<Format>(storage_format_val); | auto storage_format = static_cast<Format>(storage_format_val); | ||||
| if (storage_format != FORMAT_RESERVED) { | |||||
| auto format = tensor.GetFormat(); | |||||
| if (storage_format != FORMAT_RESERVED && storage_format != format) { | |||||
| std::vector<int64_t> storage_shape; | std::vector<int64_t> storage_shape; | ||||
| if (!AttrUtils::GetListInt(tensor, ge::ATTR_NAME_STORAGE_SHAPE, storage_shape)) { | if (!AttrUtils::GetListInt(tensor, ge::ATTR_NAME_STORAGE_SHAPE, storage_shape)) { | ||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Get][storage_shape]failed while storage_format was set."); | GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Get][storage_shape]failed while storage_format was set."); | ||||
| @@ -127,8 +128,10 @@ Status ModifyTensorDesc(GeTensorDesc &tensor) { | |||||
| GELOGD("Storage format set. update shape to [%s], and original shape to [%s]", | GELOGD("Storage format set. update shape to [%s], and original shape to [%s]", | ||||
| GeShape(storage_shape).ToString().c_str(), tensor.GetShape().ToString().c_str()); | GeShape(storage_shape).ToString().c_str(), tensor.GetShape().ToString().c_str()); | ||||
| tensor.SetShape(GeShape(std::move(storage_shape))); | |||||
| tensor.SetFormat(std::move(storage_format)); | |||||
| tensor.SetOriginShape(tensor.GetShape()); | |||||
| tensor.SetOriginFormat(format); | |||||
| tensor.SetShape(GeShape(storage_shape)); | |||||
| tensor.SetFormat(storage_format); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||