Browse Source

add mark agnostic pass

tags/v1.1.0
zhou_chao1993 4 years ago
parent
commit
0c111a4da6
2 changed files with 13 additions and 0 deletions
  1. +11
    -0
      ge/graph/passes/mark_agnostic_pass.cc
  2. +2
    -0
      ge/graph/preprocess/graph_preprocess.cc

+ 11
- 0
ge/graph/passes/mark_agnostic_pass.cc View File

@@ -16,6 +16,7 @@
#include "graph/passes/mark_agnostic_pass.h"

#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"

namespace ge {
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
@@ -47,6 +48,16 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
}
if (node_type == MERGE) {
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str());
auto in_nodes = node->GetInAllNodes();
vector<NodePtr> input_nodes(in_nodes.begin(), in_nodes.end());
/// Enter-----------+
/// +-> Merge
/// NextIteration---+
if (input_nodes.size() == 2) {
if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) {
continue;
}
}
const OpDescPtr op_desc = node->GetOpDesc();
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0);
if (op_tensor == nullptr) {


+ 2
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -117,6 +117,7 @@
#include "graph/passes/variable_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/variable_ref_delete_op_pass.h"
#include "graph/passes/mark_agnostic_pass.h"


namespace ge {
@@ -1700,6 +1701,7 @@ Status GraphPrepare::PrepareOptimize() {
try {
(void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass);
(void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass);
(void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass" , new MarkAgnosticPass);
} catch (std::bad_alloc &e) {
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
return INTERNAL_ERROR;


Loading…
Cancel
Save