Browse Source

fix storage bug.

tags/v1.2.0
unknown 3 years ago
parent
commit
b0060848d6
1 changed files with 10 additions and 3 deletions
  1. +10
    -3
      ge/generator/ge_generator.cc

+ 10
- 3
ge/generator/ge_generator.cc View File

@@ -265,7 +265,7 @@ static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag)
return SUCCESS; return SUCCESS;
} }


static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
for (auto input : inputs) { for (auto input : inputs) {
auto input_desc = input.GetTensorDesc(); auto input_desc = input.GetTensorDesc();
GeShape shape_ori = input_desc.GetShape(); GeShape shape_ori = input_desc.GetShape();
@@ -280,6 +280,12 @@ static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor>
bool is_const = false; bool is_const = false;
(void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
if (!is_const && shape_ori.GetDims().size() > 0) { if (!is_const && shape_ori.GetDims().size() > 0) {
int64_t storage_format = FORMAT_NCHW;
if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
!ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
GELOGE(FAILED, "Set attr ATTR_NAME_STORAGE_SHAPE fail.");
return FAILED;
}
desc.SetShape(dynamic_shape); desc.SetShape(dynamic_shape);
desc.SetShapeRange(dynamic_shape_range); desc.SetShapeRange(dynamic_shape_range);
} }
@@ -287,6 +293,7 @@ static void ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor>
inputTensor.SetTensorDesc(desc); inputTensor.SetTensorDesc(desc);
inputs_dynamic.push_back(inputTensor); inputs_dynamic.push_back(inputTensor);
} }
return SUCCESS;
} }


class GeGenerator::Impl { class GeGenerator::Impl {
@@ -684,8 +691,8 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) { if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) {
vector<GeTensor> inputs_dynamic; vector<GeTensor> inputs_dynamic;
vector<GeTensor> outputs_dynamic; vector<GeTensor> outputs_dynamic;
ResetTensorVecShape(inputs, inputs_dynamic);
ResetTensorVecShape(outputs, outputs_dynamic);
GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
GE_CHK_STATUS_RET_NOLOG( GE_CHK_STATUS_RET_NOLOG(
impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic)); impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
} else { } else {


Loading…
Cancel
Save