|
|
@@ -131,6 +131,22 @@ InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool IsSubgraphInputNode(const NodePtr &node) { |
|
|
|
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != DATA) || |
|
|
|
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsSubgraphOutputNode(const NodePtr &node) { |
|
|
|
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != NETOUTPUT) || |
|
|
|
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { |
|
|
|
if (src_node.GetOpDesc() == nullptr) { |
|
|
|
return nullptr; |
|
|
@@ -377,7 +393,7 @@ bool CheckIdentityIsNearSubgraph(const Node &node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// near entrance of subgraph |
|
|
|
if (in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) { |
|
|
|
if (IsSubgraphInputNode(in_node)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// near subgraph |
|
|
@@ -392,7 +408,7 @@ bool CheckIdentityIsNearSubgraph(const Node &node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// near output of subgraph |
|
|
|
if (out_node->GetType() == NETOUTPUT && NodeUtils::IsSubgraphOutput(out_node)) { |
|
|
|
if (IsSubgraphOutputNode(out_node)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// near subgraph |
|
|
|