diff --git a/ge/graph/build/stream_graph_optimizer.cc b/ge/graph/build/stream_graph_optimizer.cc index 2933d413..05049818 100644 --- a/ge/graph/build/stream_graph_optimizer.cc +++ b/ge/graph/build/stream_graph_optimizer.cc @@ -48,26 +48,41 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap } } -bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { +bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph) { if (comp_graph == nullptr) { return false; } std::set stream_set; + std::set label_set; for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) { GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue); int64_t stream_id = cur_node->GetOpDesc()->GetStreamId(); if (stream_id == kInvalidStream) { continue; } - GELOGD("Node %s in subgraph %s stream id is: %ld, node num: %zu", cur_node->GetName().c_str(), - comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize()); stream_set.insert(stream_id); + + std::string batch_label; + if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { + label_set.insert(batch_label); + } else { + GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(), + cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id); + continue; + } + + GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(), + comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize()); } - if (stream_set.size() > 1) { - GELOGI("Nodes of graph: %s have different stream id, node num: %zu, different stream num: %zu.", + if (stream_set.size() > 1 || label_set.size() > 1) { + GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size()); return false; } + + if (!label_set.empty()) { + (void)AttrUtils::SetStr(comp_graph, ATTR_NAME_BATCH_LABEL, *label_set.begin()); + } return true; } @@ -99,8 +114,8 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com continue; } - if (!IsSameStreamId(subgraph)) { - GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); + if (!IsSameStreamIdOrBatchLabel(subgraph)) { + GELOGI("There are more than one stream or batch_label in subgraph %s", subgraph->GetName().c_str()); continue; } OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); @@ -112,9 +127,11 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com return FAILED; } run_context.stream = run_context.graphStreamList[stream_id]; - GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", - subgraph->GetName().c_str(), engine_name.c_str(), stream_id, - static_cast(reinterpret_cast(run_context.stream))); + std::string batch_label; + (void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label); + GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, " + "batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id, + static_cast(reinterpret_cast(run_context.stream)), batch_label.c_str()); for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { GE_CHECK_NOTNULL(*iter); Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); diff --git a/ge/graph/build/stream_graph_optimizer.h b/ge/graph/build/stream_graph_optimizer.h index b0eea135..d69fa7ba 100644 --- a/ge/graph/build/stream_graph_optimizer.h +++ b/ge/graph/build/stream_graph_optimizer.h @@ -41,7 +41,7 @@ class StreamGraphOptimizer { private: void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); - bool IsSameStreamId(const ComputeGraphPtr &comp_graph); + bool IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph); }; } // namespace ge #endif // GE_GRAPH_BUILD_OPTIMIZE_STREAM_GRAPH_H_ diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index b506f945..2089ad31 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -567,7 +567,7 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector &ops, bool is_ continue; } string op_type = op_desc->GetType(); - if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0)) { + if (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0) { continuous_op_lists.emplace_back(vector()); } else { continuous_op_lists.back().emplace_back(op_desc); diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 9ab74d70..a90f145e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1407,11 +1407,13 @@ Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); - if (multi_batch_with_case != nullptr) { - PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - return pass_manager.Run(graph); + if (GetLocalOmgContext().dynamic_node_type.empty()) { + const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); + if (multi_batch_with_switchn == nullptr) { + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + return pass_manager.Run(graph); + } } if (!GetLocalOmgContext().need_multi_batch) { GELOGI("No need to process_multi for no_train graph.");