diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index d162d58e..016f9ef2 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -262,6 +262,15 @@ static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) change_shape_flag = true; } } + for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) { + auto output_desc = op_desc->MutableOutputDesc(static_cast(i)); + GE_CHECK_NOTNULL(output_desc); + // pass scalar output desc + auto dims = output_desc->GetShape().GetDims(); + if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) { + change_shape_flag = true; + } + } return SUCCESS; } diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc index d50b6df9..6fa63642 100644 --- a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc @@ -113,16 +113,13 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) { GE_CHECK_NOTNULL(op_desc); std::vector dynamic_shape_dims = {kDynamicShapeDim}; GeShape dynamic_shape(dynamic_shape_dims); - bool reset_shape_flag = false; - if (ResetInputTensorShape(op_desc, dynamic_shape, reset_shape_flag) == SUCCESS && reset_shape_flag) { - (void)ResetOutputTensorShape(op_desc, dynamic_shape); - } + (void)ResetInputTensorShape(op_desc, dynamic_shape); + (void)ResetOutputTensorShape(op_desc, dynamic_shape); return SUCCESS; } -Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, - bool &reset_shape_flag) { - reset_shape_flag = false; +Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, + const GeShape &dynamic_shape) { GE_CHECK_NOTNULL(op_desc); for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { auto input_desc = op_desc->MutableInputDesc(static_cast(i)); @@ -136,7 +133,6 @@ Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, if (CheckIfConstInput(input_desc)) { continue; } - reset_shape_flag = true; input_desc->SetShape(dynamic_shape); } return SUCCESS; diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.h b/ge/graph/passes/dynamic_single_op_reset_shape_pass.h index 897fcac6..765448ff 100644 --- a/ge/graph/passes/dynamic_single_op_reset_shape_pass.h +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.h @@ -27,7 +27,7 @@ class DynamicSingleOpResetShapePass : public GraphPass { private: Status ResetOpShape(OpDescPtr &op_desc); - Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, bool &reset_shape_flag); + Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape); Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape); Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu); bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc);