|
|
@@ -16,6 +16,7 @@ |
|
|
|
#include "graph/passes/mark_agnostic_pass.h" |
|
|
|
|
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { |
|
|
@@ -47,6 +48,16 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { |
|
|
|
} |
|
|
|
if (node_type == MERGE) { |
|
|
|
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); |
|
|
|
auto in_nodes = node->GetInAllNodes(); |
|
|
|
vector<NodePtr> input_nodes(in_nodes.begin(), in_nodes.end()); |
|
|
|
/// Enter-----------+ |
|
|
|
/// +-> Merge |
|
|
|
/// NextIteration---+ |
|
|
|
if (input_nodes.size() == 2) { |
|
|
|
if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
const OpDescPtr op_desc = node->GetOpDesc(); |
|
|
|
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); |
|
|
|
if (op_tensor == nullptr) { |
|
|
|