From 67974b31362c13d8fa986abc7ecdee3f9e50b2f4 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 17 Jul 2021 14:41:15 +0800 Subject: [PATCH] fix pytorch infershape origin shape --- ge/graph/passes/infershape_pass.cc | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 05b1b5fc..0555929d 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -228,19 +228,13 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe } graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { - changed = !SameTensorDesc(src, dst); - // refresh src itself - src->SetOriginShape(src->GetShape()); - src->SetOriginDataType(src->GetDataType()); - TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); - vector> src_shape_range; - src->GetShapeRange(src_shape_range); - src->SetOriginShapeRange(src_shape_range); - - if (!changed) { + changed = false; + if (SameTensorDesc(src, dst)) { GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); return SUCCESS; } + + changed = true; UpdateShapeAndDType(src, dst); GELOGD( "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."