diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index fdb825a9..8d8e48ad 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -131,6 +131,22 @@ InputRWType GetInputRwTypeInConflict(const std::set &rw_type_set) { } } +bool IsSubgraphInputNode(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != DATA) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { + return false; + } + return true; +} + +bool IsSubgraphOutputNode(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != NETOUTPUT) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { + return false; + } + return true; +} + NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { if (src_node.GetOpDesc() == nullptr) { return nullptr; @@ -377,7 +393,7 @@ bool CheckIdentityIsNearSubgraph(const Node &node) { continue; } // near entrance of subgraph - if (in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) { + if (IsSubgraphInputNode(in_node)) { return true; } // near subgraph @@ -392,7 +408,7 @@ bool CheckIdentityIsNearSubgraph(const Node &node) { continue; } // near output of subgraph - if (out_node->GetType() == NETOUTPUT && NodeUtils::IsSubgraphOutput(out_node)) { + if (IsSubgraphOutputNode(out_node)) { return true; } // near subgraph diff --git a/metadef b/metadef index c9b69607..be949d5f 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit c9b6960725036291ed2328f5751beb4f01247526 +Subproject commit be949d5ff32baec332aa8765d2b211334ae84dbf diff --git a/parser b/parser index 9e051f61..d865fa6e 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 9e051f61274f0655c8e1d9a1a8f481c051063dae +Subproject commit d865fa6e67c00c536e6df2f86d4912c1f1feff4c