Browse Source

!608 Add pytorch input const.

From: @zhao_zhixuan
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
20bf93248f
4 changed files with 33 additions and 3 deletions
  1. +21
    -1
      ge/generator/ge_generator.cc
  2. +10
    -0
      ge/graph/preprocess/graph_preprocess.cc
  3. +1
    -1
      metadef
  4. +1
    -1
      parser

+ 21
- 1
ge/generator/ge_generator.cc View File

@@ -156,7 +156,12 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen
}

string op_type;
if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
bool is_const = false;
(void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
if (is_const) {
GELOGD("Get input[%d] is const", index);
op_type = CONSTANTOP;
} else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
op_type = DATA;
}

@@ -165,6 +170,18 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen
if (data_op == nullptr) {
return FAILED;
}
if (is_const) {
ConstGeTensorPtr tensor_value;
if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str());
return FAILED;
}
if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail.");
return FAILED;
}
}

(void)AttrUtils::SetBool(data_op, "_is_single_op", true);

GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail.");
@@ -557,6 +574,9 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor>
Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline) {
if (!is_offline) {
(void)AttrUtils::SetBool(op_desc, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true);
}

if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
GELOGE(PARAM_INVALID, "input param is invalid when build single op!");


+ 10
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1797,6 +1797,16 @@ Status GraphPrepare::PrepareOptimize() {
}

void GraphPrepare::TypeConversionOfConstant() {
bool is_acl_compile = false;
for (ge::NodePtr &n : compute_graph_->GetAllNodes()) {
// This can ensure that n is not a null pointer
// No Conversion when called by aclOpCompile
(void)AttrUtils::GetBool(n->GetOpDesc(), ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, is_acl_compile);
if (is_acl_compile) {
return;
}
}

if (options_.train_graph_flag) {
GELOGD("trans CONSTANT to CONSTANTOP in train.");
for (ge::NodePtr &n : compute_graph_->GetAllNodes()) {


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit bd2cfdfa85a3d9dcbd7dc825f5759c7f8b3ffa9a
Subproject commit c85822cd5404e40cb4ff2bfc9483062648c13c57

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit c78651fee671ac079c56d2c3ff0d0439ea82f2fa
Subproject commit 5bc8c38b37476e8f4b9391c96e4a2cca59e53d8e

Loading…
Cancel
Save