Browse Source

fix pytorch infershape origin shape

tags/v1.5.1
zhaoxinxin 3 years ago
parent
commit
67974b3136
1 changed files with 4 additions and 10 deletions
  1. +4
    -10
      ge/graph/passes/infershape_pass.cc

+ 4
- 10
ge/graph/passes/infershape_pass.cc View File

@@ -228,19 +228,13 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe
}
graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
changed = !SameTensorDesc(src, dst);
// refresh src itself
src->SetOriginShape(src->GetShape());
src->SetOriginDataType(src->GetDataType());
TensorUtils::SetRealDimCnt(*src, static_cast<uint32_t>(src->GetOriginShape().GetDims().size()));
vector<pair<int64_t, int64_t>> src_shape_range;
src->GetShapeRange(src_shape_range);
src->SetOriginShapeRange(src_shape_range);
if (!changed) {
changed = false;
if (SameTensorDesc(src, dst)) {
GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update.");
return SUCCESS;
}
changed = true;
UpdateShapeAndDType(src, dst);
GELOGD(
"UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s."


Loading…
Cancel
Save