diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 72337b26..81b2f182 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -462,7 +462,7 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector all_reduce_succs; for (const NodePtr &node : graph->GetDirectNode()) { - if ((node->GetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE) || + if (!IsHcomNode(node->GetType()) || node->GetInDataNodes().size() <= 1) { continue; } @@ -507,7 +507,7 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vectorGetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE)) { + if (!IsHcomNode(node->GetType())) { GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); node->GetOpDesc()->SetStreamId(new_stream); } @@ -517,6 +517,11 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector &scheduler_confs, const map &max_parallel_num) : scheduler_confs_(scheduler_confs), max_parallel_num_(max_parallel_num) {} diff --git a/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h index e09d7cd6..0aebb9b4 100644 --- a/ge/graph/build/logical_stream_allocator.h +++ b/ge/graph/build/logical_stream_allocator.h @@ -166,6 +166,8 @@ class AllReduceParallelPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + private: + bool IsHcomNode(const std::string& node_type); }; // Assign logical streams which is not limited by the number of tasks.