Browse Source

!1213 fix unknonwn shape calc memory bug

From: @wan_xuelei
Reviewed-by: @xchu42,@wqtshg
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
5c18e5df78
5 changed files with 26 additions and 22 deletions
  1. +3
    -1
      ge/graph/build/model_builder.cc
  2. +2
    -2
      ge/graph/passes/dimension_adjust_pass.cc
  3. +11
    -11
      ge/graph/passes/flow_ctrl_pass.cc
  4. +5
    -4
      ge/graph/passes/resource_pair_add_control_pass.cc
  5. +5
    -4
      ge/graph/passes/resource_pair_remove_control_pass.cc

+ 3
- 1
ge/graph/build/model_builder.cc View File

@@ -261,7 +261,9 @@ Status ModelBuilder::SetInputOutputDesc() {
GE_IF_BOOL_EXEC(n->GetInAllNodes().empty() && n->GetOutAllNodes().empty(), continue;);

SetInputIsConst(n);
if (IsGeLocalOp(n->GetOpDesc())) {
bool is_unknow = false;
(void)NodeUtils::GetNodeUnknownShapeStatus(*n, is_unknow);
if ((IsGeLocalOp(n->GetOpDesc())) && (!is_unknow)) {
GE_CHK_STATUS_RET(CalcOutputSize(n), "Calculate output size failed");
}
ret = AdjustConstWeightSize(n, weight_offset_);


+ 2
- 2
ge/graph/passes/dimension_adjust_pass.cc View File

@@ -29,13 +29,13 @@ const int kRemoveInputIndex = 1;

Status DimensionAdjustPass::Run(ge::NodePtr &node) {
if (node == nullptr) {
GELOGE(PARAM_INVALID, "node is nullptr");
GELOGE(PARAM_INVALID, "node is nullptr.");
return PARAM_INVALID;
}

OpDescPtr op_desc_ptr = node->GetOpDesc();
if (op_desc_ptr == nullptr) {
GELOGE(PARAM_INVALID, "GetOpDesc return nullptr");
GELOGE(PARAM_INVALID, "GetOpDesc return nullptr.");
return PARAM_INVALID;
}



+ 11
- 11
ge/graph/passes/flow_ctrl_pass.cc View File

@@ -33,11 +33,11 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) {
GE_CHECK_NOTNULL(compute_graph);

if (!PassUtils::IsNeedTrainIteFlowCtrl(compute_graph)) {
GELOGI("No need FlowCtrl for graph %u", compute_graph->GetGraphID());
GELOGI("No need FlowCtrl for graph %u.", compute_graph->GetGraphID());
return NOT_CHANGED;
}

GELOGI("FlowCtrl pass begin.graph is [%s]", compute_graph->GetName().c_str());
GELOGI("FlowCtrl pass begin.graph is [%s].", compute_graph->GetName().c_str());
bool graph_change = false;
// 1. Add FP/BP flow ctrl (big cycle)
for (auto &node : compute_graph->GetDirectNode()) {
@@ -347,11 +347,11 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
NodePtr assign_node =
InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node);
if (assign_node == nullptr || switch_node == nullptr) {
GELOGE(PARAM_INVALID, "assign_node or switch node is null");
GELOGE(PARAM_INVALID, "assign_node or switch node is null.");
return FAILED;
}

GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed");
GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed.");

graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor());
if (add_ret != GRAPH_SUCCESS) {
@@ -370,7 +370,7 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
}
GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed");
GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()),
"set switch branch node label failed");
"set switch branch node label failed.");

string model_exit_name = switch_node->GetName() + "_ModelExit";
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed");
@@ -401,7 +401,7 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
}

Status FlowCtrlPass::AddFpBpIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &pre_node) {
GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr"); return FAILED);
GE_IF_BOOL_EXEC(pre_node == nullptr, DOMI_LOGE("pre_node is nullptr."); return FAILED);
string pre_node_name = pre_node->GetName();
GELOGI("Add FpBp Iterator ctrl, pre node:%s.", pre_node_name.c_str());
// 1. Get or add variables
@@ -477,7 +477,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
* itersPerLoop loopCond
*/
GE_IF_BOOL_EXEC(loop_after_node == nullptr || compute_graph == nullptr,
DOMI_LOGE("loop after node or compute graph is null"); return FAILED);
DOMI_LOGE("loop after node or compute graph is null."); return FAILED);
InDataAnchorPtr in_anchor = loop_after_node->GetInDataAnchor(0);
if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr) {
GELOGE(FAILED, "Find %s in data anchor failed.", loop_after_node->GetName().c_str());
@@ -498,7 +498,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
}

// 2. Add StreamSwitch and edges to switch_node.
GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null"); return FAILED);
GE_IF_BOOL_EXEC(loop_pre_node == nullptr, DOMI_LOGE("loop pre node is null."); return FAILED);
string switch_name = loop_pre_node->GetName() + "_" + NODE_NAME_STREAM_SWITCH;
NodePtr switch_node = InsertStreamSwitchOp(compute_graph, switch_name, loop_cond_node, iter_per_loop_node);
if (switch_node == nullptr) {
@@ -506,7 +506,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
return FAILED;
}

GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed");
GE_CHK_STATUS_RET(SetStreamLabel(switch_node, switch_name), "set stream label failed.");

graphStatus add_ret = GraphUtils::AddEdge(loop_pre_node->GetOutControlAnchor(), switch_node->GetInControlAnchor());
if (add_ret != GRAPH_SUCCESS) {
@@ -529,7 +529,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
return FAILED;
}

GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed");
GE_CHK_STATUS_RET(SetStreamLabel(active_node, active_name), "set stream label failed.");

GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED);
@@ -542,7 +542,7 @@ Status FlowCtrlPass::AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph,
}

// used for stream assign to find true branch
GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed");
GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, { active_name }), "set active label list failed.");
// used for stream assign to find active stream
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { loop_pre_node->GetName() }), "set active label list failed");
active_nodes_in_iter_loop_.push_back(active_node);


+ 5
- 4
ge/graph/passes/resource_pair_add_control_pass.cc View File

@@ -63,16 +63,17 @@ Status ResourcePairAddControlPass::Run(ComputeGraphPtr graph) {
NodePtr from_node = prefix_2_node.second;
GE_CHECK_NOTNULL(from_node);
auto to_item_prefix_2_node = prefix_2_node_per_type.find(resource_type_pair.second);
// stackpush and stackpop may exist in two subgraphs, no necessary to report error
if (to_item_prefix_2_node == prefix_2_node_per_type.end()) {
GELOGE(PARAM_INVALID, "find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
GELOGW("find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
resource_type_pair.first.c_str(), resource_type_pair.second.c_str());
return PARAM_INVALID;
continue;
}
auto to_prefix_2_node = to_item_prefix_2_node->second.find(prefix);
if (to_prefix_2_node == to_item_prefix_2_node->second.end()) {
GELOGE(PARAM_INVALID, "find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
GELOGW("find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
resource_type_pair.first.c_str(), resource_type_pair.second.c_str());
return PARAM_INVALID;
continue;
}
NodePtr to_node = to_prefix_2_node->second;
GE_CHECK_NOTNULL(to_node);


+ 5
- 4
ge/graph/passes/resource_pair_remove_control_pass.cc View File

@@ -63,16 +63,17 @@ Status ResourcePairRemoveControlPass::Run(ComputeGraphPtr graph) {
NodePtr from_node = prefix_2_node.second;
GE_CHECK_NOTNULL(from_node);
auto to_item_prefix_2_node = prefix_2_node_per_type.find(resource_type_pair.second);
// stackpush and stackpop may exist in two subgraphs, no necessary to report error
if (to_item_prefix_2_node == prefix_2_node_per_type.end()) {
GELOGE(INTERNAL_ERROR, "find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
GELOGW("find peer type node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
resource_type_pair.first.c_str(), resource_type_pair.second.c_str());
return domi::PARAM_INVALID;
continue;
}
auto to_prefix_2_node = to_item_prefix_2_node->second.find(prefix);
if (to_prefix_2_node == to_item_prefix_2_node->second.end()) {
GELOGE(INTERNAL_ERROR, "find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
GELOGW("find peer prefix node fail, suffix:%s, from_type:%s, to_type:%s", prefix.c_str(),
resource_type_pair.first.c_str(), resource_type_pair.second.c_str());
return domi::PARAM_INVALID;
continue;
}
NodePtr to_node = to_prefix_2_node->second;
GE_CHECK_NOTNULL(to_node);


Loading…
Cancel
Save