|
@@ -363,13 +363,10 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Update stream id for nodes belong to skipped engine subgraph |
|
|
|
|
|
GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs)); |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { |
|
|
|
|
|
|
|
|
int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const { |
|
|
set<int64_t> stream_ids; |
|
|
set<int64_t> stream_ids; |
|
|
|
|
|
|
|
|
for (const auto &in_node : node->GetInAllNodes()) { |
|
|
for (const auto &in_node : node->GetInAllNodes()) { |
|
@@ -398,8 +395,7 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { |
|
|
return kInvalidStream; |
|
|
return kInvalidStream; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, |
|
|
|
|
|
const vector<SubgraphPtr> &subgraphs) { |
|
|
|
|
|
|
|
|
Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { |
|
|
set<OpDescPtr> ops_without_label; |
|
|
set<OpDescPtr> ops_without_label; |
|
|
|
|
|
|
|
|
// Check if subgraph is engine skipped and without stream label or not |
|
|
// Check if subgraph is engine skipped and without stream label or not |
|
@@ -441,7 +437,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { |
|
|
|
|
|
|
|
|
bool UpdateForSkippedEnginePass::AreAllPredStreamsInvalid(const NodePtr &node) const { |
|
|
for (const auto &pre_node : node->GetInAllNodes()) { |
|
|
for (const auto &pre_node : node->GetInAllNodes()) { |
|
|
auto pre_node_desc = pre_node->GetOpDesc(); |
|
|
auto pre_node_desc = pre_node->GetOpDesc(); |
|
|
if (pre_node_desc != nullptr) { |
|
|
if (pre_node_desc != nullptr) { |
|
@@ -653,12 +649,14 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec |
|
|
if (context_.enable_single_stream) { |
|
|
if (context_.enable_single_stream) { |
|
|
passes.emplace_back(MakeShared<SingleStreamPass>()); |
|
|
passes.emplace_back(MakeShared<SingleStreamPass>()); |
|
|
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); |
|
|
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); |
|
|
|
|
|
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>()); |
|
|
} else { |
|
|
} else { |
|
|
passes.emplace_back(MakeShared<AssignByLabelPass>()); |
|
|
passes.emplace_back(MakeShared<AssignByLabelPass>()); |
|
|
passes.emplace_back(MakeShared<IndependentStreamPass>()); |
|
|
passes.emplace_back(MakeShared<IndependentStreamPass>()); |
|
|
passes.emplace_back(MakeShared<AssignByDependencyPass>()); |
|
|
passes.emplace_back(MakeShared<AssignByDependencyPass>()); |
|
|
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); |
|
|
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); |
|
|
passes.emplace_back(MakeShared<AllReduceParallelPass>()); |
|
|
passes.emplace_back(MakeShared<AllReduceParallelPass>()); |
|
|
|
|
|
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (auto &pass : passes) { |
|
|
for (auto &pass : passes) { |
|
|