| @@ -228,13 +228,19 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe | |||||
| } | } | ||||
| graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { | ||||
| changed = false; | |||||
| if (SameTensorDesc(src, dst)) { | |||||
| changed = !SameTensorDesc(src, dst); | |||||
| // refresh src itself | |||||
| src->SetOriginShape(src->GetShape()); | |||||
| src->SetOriginDataType(src->GetDataType()); | |||||
| TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size())); | |||||
| vector<pair<int64_t, int64_t>> src_shape_range; | |||||
| src->GetShapeRange(src_shape_range); | |||||
| src->SetOriginShapeRange(src_shape_range); | |||||
| if (!changed) { | |||||
| GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); | GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| changed = true; | |||||
| UpdateShapeAndDType(src, dst); | UpdateShapeAndDType(src, dst); | ||||
| GELOGD( | GELOGD( | ||||
| "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." | "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." | ||||