Browse Source

add mark agnostic pass

tags/v1.1.0
zhou_chao1993 3 years ago
parent
commit
60f30d7ceb
5 changed files with 30 additions and 4 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/ge_inference.mk
  3. +1
    -0
      ge/ge_runner.mk
  4. +24
    -4
      ge/graph/passes/mark_agnostic_pass.cc
  5. +2
    -0
      ge/graph/preprocess/graph_preprocess.cc

+ 2
- 0
ge/CMakeLists.txt View File

@@ -137,6 +137,7 @@ set(TRAIN_SRC_LIST
"graph/passes/atomic_addr_clean_pass.cc"
"graph/passes/mark_same_addr_pass.cc"
"graph/passes/mark_graph_unknown_status_pass.cc"
"graph/passes/mark_agnostic_pass.cc"
"graph/partition/dynamic_shape_partition.cc"
"graph/partition/stage_partition.cc"
"graph/passes/base_pass.cc"
@@ -488,6 +489,7 @@ set(INFER_SRC_LIST
"graph/passes/atomic_addr_clean_pass.cc"
"graph/passes/mark_same_addr_pass.cc"
"graph/passes/mark_graph_unknown_status_pass.cc"
"graph/passes/mark_agnostic_pass.cc"
"graph/common/omg_util.cc"
"graph/common/bcast.cc"
"graph/common/local_context.cc"


+ 1
- 0
ge/ge_inference.mk View File

@@ -109,6 +109,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/atomic_addr_clean_pass.cc \
graph/passes/mark_same_addr_pass.cc \
graph/passes/mark_graph_unknown_status_pass.cc \
graph/passes/mark_agnostic_pass.cc \
graph/common/omg_util.cc \
graph/common/bcast.cc \
graph/common/local_context.cc \


+ 1
- 0
ge/ge_runner.mk View File

@@ -110,6 +110,7 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/atomic_addr_clean_pass.cc \
graph/passes/mark_same_addr_pass.cc \
graph/passes/mark_graph_unknown_status_pass.cc \
graph/passes/mark_agnostic_pass.cc \
graph/partition/dynamic_shape_partition.cc \
graph/partition/stage_partition.cc \
graph/passes/base_pass.cc \


+ 24
- 4
ge/graph/passes/mark_agnostic_pass.cc View File

@@ -15,20 +15,40 @@
*/
#include "graph/passes/mark_agnostic_pass.h"

#include "utils/node_utils.h"
#include "graph/utils/node_utils.h"

namespace ge {
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
for (const auto &node : graph->GetDirectNode()) {
auto node_type = NodeUtils::GetNodeType(*node);
if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) {
GELOGD("Mark format agnostic for switch ndoe %s", node->GetName().c_str());
GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str());
const OpDescPtr op_desc = node->GetOpDesc();
const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0);
if (op_tensor == nullptr) {
GELOGD("Op: %s, Index:0,has no input", node->GetName().c_str());
continue;
}
AttrUtils::SetInt(op_tensor, "_format_continuous", 1);
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1);
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1}));
continue;
}
if (node_type == IDENTITY) {
GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str());
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1);
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1}));
continue;
}
if (node_type == MERGE || node_type == REFMERGE) {
GELOGD("Mark format agnostic for merge node %s", node->GetName().c_str());
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str());
const OpDescPtr op_desc = node->GetOpDesc();
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0);
if (op_tensor == nullptr) {
GELOGD("Op: %s, Index:0,has no output", node->GetName().c_str());
continue;
}
AttrUtils::SetInt(op_tensor, "_format_continuous", 1);
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1);
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector<int64_t>({1}));
continue;
@@ -36,4 +56,4 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) {
}
return SUCCESS;
}
}
} // namespace ge

+ 2
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -92,6 +92,7 @@
#include "graph/passes/unused_op_remove_pass.h"
#include "graph/passes/var_is_initialized_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/mark_agnostic_pass.h"
#include "graph/preprocess/insert_op/util_insert_aipp_op.h"
#include "graph/types.h"
#include "graph/utils/tensor_utils.h"
@@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() {
try {
(void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass);
(void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass);
(void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass);
} catch (std::bad_alloc &e) {
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs.");
return INTERNAL_ERROR;


Loading…
Cancel
Save