Browse Source

fix variable fusion when variable only in subgraph

tags/v1.3.0
chuxing 3 years ago
parent
commit
b153ca0391
2 changed files with 4 additions and 3 deletions
  1. +1
    -1
      ge/graph/load/model_manager/davinci_model.cc
  2. +3
    -2
      ge/graph/passes/variable_op_pass.cc

+ 1
- 1
ge/graph/load/model_manager/davinci_model.cc View File

@@ -3904,7 +3904,7 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id)
}

std::vector<NodePtr> variable_node_list;
for (ge::NodePtr &node : graph->GetDirectNode()) {
for (ge::NodePtr &node : graph->GetAllNodes()) {
if (node == nullptr) {
continue;
}


+ 3
- 2
ge/graph/passes/variable_op_pass.cc View File

@@ -119,8 +119,9 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
return INTERNAL_ERROR;
}

auto graph_id = GraphUtils::FindRootGraph(graph)->GetGraphID();
GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(),
GetContext().SessionId(), graph->GetGraphID());
GetContext().SessionId(), graph_id);

if (var_accelerate_ctrl_ == nullptr) {
GELOGE(INTERNAL_ERROR, "Failed to run var op pass, the variable accelerate control is null");
@@ -176,7 +177,7 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
GELOGE(INTERNAL_ERROR, "Failed to update the format fusion road for var %s", node->GetName().c_str());
return INTERNAL_ERROR;
}
ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph->GetGraphID());
ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph_id);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to update the graph id for var %s", node->GetName().c_str());
return INTERNAL_ERROR;


Loading…
Cancel
Save