Browse Source

!669 change l2 buffer for mult_batch

From: @jiming6
Reviewed-by: @liujunzhu,@xchu42
Signed-off-by: @wqtshg
tags/v1.2.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
5d743d9ea2
4 changed files with 36 additions and 17 deletions
  1. +27
    -10
      ge/graph/build/stream_graph_optimizer.cc
  2. +1
    -1
      ge/graph/build/stream_graph_optimizer.h
  3. +1
    -1
      ge/graph/build/task_generator.cc
  4. +7
    -5
      ge/graph/preprocess/multi_batch_copy_graph.cc

+ 27
- 10
ge/graph/build/stream_graph_optimizer.cc View File

@@ -48,26 +48,41 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap
} }
} }


bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) {
bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph) {
if (comp_graph == nullptr) { if (comp_graph == nullptr) {
return false; return false;
} }
std::set<int64_t> stream_set; std::set<int64_t> stream_set;
std::set<std::string> label_set;
for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) { for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) {
GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue); GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue);
int64_t stream_id = cur_node->GetOpDesc()->GetStreamId(); int64_t stream_id = cur_node->GetOpDesc()->GetStreamId();
if (stream_id == kInvalidStream) { if (stream_id == kInvalidStream) {
continue; continue;
} }
GELOGD("Node %s in subgraph %s stream id is: %ld, node num: %zu", cur_node->GetName().c_str(),
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
stream_set.insert(stream_id); stream_set.insert(stream_id);

std::string batch_label;
if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
label_set.insert(batch_label);
} else {
GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(),
cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id);
continue;
}

GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(),
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
} }
if (stream_set.size() > 1) {
GELOGI("Nodes of graph: %s have different stream id, node num: %zu, different stream num: %zu.",
if (stream_set.size() > 1 || label_set.size() > 1) {
GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.",
comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size()); comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size());
return false; return false;
} }

if (!label_set.empty()) {
(void)AttrUtils::SetStr(comp_graph, ATTR_NAME_BATCH_LABEL, *label_set.begin());
}
return true; return true;
} }


@@ -99,8 +114,8 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
continue; continue;
} }


if (!IsSameStreamId(subgraph)) {
GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str());
if (!IsSameStreamIdOrBatchLabel(subgraph)) {
GELOGI("There are more than one stream or batch_label in subgraph %s", subgraph->GetName().c_str());
continue; continue;
} }
OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); OpDescPtr op_desc = nodes.at(0)->GetOpDesc();
@@ -112,9 +127,11 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
return FAILED; return FAILED;
} }
run_context.stream = run_context.graphStreamList[stream_id]; run_context.stream = run_context.graphStreamList[stream_id];
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.",
subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)));
std::string batch_label;
(void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label);
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, "
"batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str());
for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) {
GE_CHECK_NOTNULL(*iter); GE_CHECK_NOTNULL(*iter);
Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context);


+ 1
- 1
ge/graph/build/stream_graph_optimizer.h View File

@@ -41,7 +41,7 @@ class StreamGraphOptimizer {
private: private:
void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map);


bool IsSameStreamId(const ComputeGraphPtr &comp_graph);
bool IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph);
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_BUILD_OPTIMIZE_STREAM_GRAPH_H_ #endif // GE_GRAPH_BUILD_OPTIMIZE_STREAM_GRAPH_H_

+ 1
- 1
ge/graph/build/task_generator.cc View File

@@ -567,7 +567,7 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector<OpDescPtr> &ops, bool is_
continue; continue;
} }
string op_type = op_desc->GetType(); string op_type = op_desc->GetType();
if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0)) {
if (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0) {
continuous_op_lists.emplace_back(vector<OpDescPtr>()); continuous_op_lists.emplace_back(vector<OpDescPtr>());
} else { } else {
continuous_op_lists.back().emplace_back(op_desc); continuous_op_lists.back().emplace_back(op_desc);


+ 7
- 5
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -1407,11 +1407,13 @@ Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
} }


Status ProcessMultiBatch(ComputeGraphPtr &graph) { Status ProcessMultiBatch(ComputeGraphPtr &graph) {
const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE");
if (multi_batch_with_case != nullptr) {
PassManager pass_manager;
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
return pass_manager.Run(graph);
if (GetLocalOmgContext().dynamic_node_type.empty()) {
const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN");
if (multi_batch_with_switchn == nullptr) {
PassManager pass_manager;
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
return pass_manager.Run(graph);
}
} }
if (!GetLocalOmgContext().need_multi_batch) { if (!GetLocalOmgContext().need_multi_batch) {
GELOGI("No need to process_multi for no_train graph."); GELOGI("No need to process_multi for no_train graph.");


Loading…
Cancel
Save