Browse Source

Fix bug of modify output shape to -2.

tags/v1.2.0
unknown 3 years ago
parent
commit
ec5326ca4c
3 changed files with 14 additions and 9 deletions
  1. +9
    -0
      ge/generator/ge_generator.cc
  2. +4
    -8
      ge/graph/passes/dynamic_single_op_reset_shape_pass.cc
  3. +1
    -1
      ge/graph/passes/dynamic_single_op_reset_shape_pass.h

+ 9
- 0
ge/generator/ge_generator.cc View File

@@ -262,6 +262,15 @@ static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag)
change_shape_flag = true;
}
}
for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) {
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i));
GE_CHECK_NOTNULL(output_desc);
// pass scalar output desc
auto dims = output_desc->GetShape().GetDims();
if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) {
change_shape_flag = true;
}
}
return SUCCESS;
}



+ 4
- 8
ge/graph/passes/dynamic_single_op_reset_shape_pass.cc View File

@@ -113,16 +113,13 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(op_desc);
std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim};
GeShape dynamic_shape(dynamic_shape_dims);
bool reset_shape_flag = false;
if (ResetInputTensorShape(op_desc, dynamic_shape, reset_shape_flag) == SUCCESS && reset_shape_flag) {
(void)ResetOutputTensorShape(op_desc, dynamic_shape);
}
(void)ResetInputTensorShape(op_desc, dynamic_shape);
(void)ResetOutputTensorShape(op_desc, dynamic_shape);
return SUCCESS;
}

Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape,
bool &reset_shape_flag) {
reset_shape_flag = false;
Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc,
const GeShape &dynamic_shape) {
GE_CHECK_NOTNULL(op_desc);
for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
@@ -136,7 +133,6 @@ Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc,
if (CheckIfConstInput(input_desc)) {
continue;
}
reset_shape_flag = true;
input_desc->SetShape(dynamic_shape);
}
return SUCCESS;


+ 1
- 1
ge/graph/passes/dynamic_single_op_reset_shape_pass.h View File

@@ -27,7 +27,7 @@ class DynamicSingleOpResetShapePass : public GraphPass {

private:
Status ResetOpShape(OpDescPtr &op_desc);
Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, bool &reset_shape_flag);
Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape);
Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape);
Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu);
bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc);


Loading…
Cancel
Save