diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 5c8bc46c..df65f0a9 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -363,13 +363,10 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector stream_ids; for (const auto &in_node : node->GetInAllNodes()) { @@ -398,8 +395,7 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { return kInvalidStream; } -Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, - const vector &subgraphs) { +Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { set ops_without_label; // Check if subgraph is engine skipped and without stream label or not @@ -441,7 +437,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph return SUCCESS; } -bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { +bool UpdateForSkippedEnginePass::AreAllPredStreamsInvalid(const NodePtr &node) const { for (const auto &pre_node : node->GetInAllNodes()) { auto pre_node_desc = pre_node->GetOpDesc(); if (pre_node_desc != nullptr) { @@ -653,12 +649,14 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec if (context_.enable_single_stream) { passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); } else { passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); } for (auto &pass : passes) { diff --git a/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h index 0aebb9b4..b9aec611 100644 --- a/ge/graph/build/logical_stream_allocator.h +++ b/ge/graph/build/logical_stream_allocator.h @@ -147,15 +147,20 @@ class NodeStreamUpdatePass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; +}; - private: +// Update the stream of subgraphs to nodes. +class UpdateForSkippedEnginePass : public LogicalStreamPass { + public: + STREAM_PASS_DEFAULT_FUNC(UpdateForSkippedEnginePass); /// Optimize for case like: /// NodeA(stream1) -> Const(stream2) -> NodeB(stream1) /// To case: /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) - Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector &subgraphs); + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + private: int64_t GetSingleInoutStream(const NodePtr &node) const; // Judge if all predecessors' streams of node are kInvalidStream bool AreAllPredStreamsInvalid(const NodePtr &node) const;