diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index fa315516..51ef2009 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -137,6 +137,7 @@ set(TRAIN_SRC_LIST "graph/passes/atomic_addr_clean_pass.cc" "graph/passes/mark_same_addr_pass.cc" "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/mark_agnostic_pass.cc" "graph/partition/dynamic_shape_partition.cc" "graph/partition/stage_partition.cc" "graph/passes/base_pass.cc" @@ -488,6 +489,7 @@ set(INFER_SRC_LIST "graph/passes/atomic_addr_clean_pass.cc" "graph/passes/mark_same_addr_pass.cc" "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/mark_agnostic_pass.cc" "graph/common/omg_util.cc" "graph/common/bcast.cc" "graph/common/local_context.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index ac106346..bbf18b90 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -109,6 +109,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/atomic_addr_clean_pass.cc \ graph/passes/mark_same_addr_pass.cc \ graph/passes/mark_graph_unknown_status_pass.cc \ + graph/passes/mark_agnostic_pass.cc \ graph/common/omg_util.cc \ graph/common/bcast.cc \ graph/common/local_context.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 6c448a46..2af4cbc6 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -110,6 +110,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/atomic_addr_clean_pass.cc \ graph/passes/mark_same_addr_pass.cc \ graph/passes/mark_graph_unknown_status_pass.cc \ + graph/passes/mark_agnostic_pass.cc \ graph/partition/dynamic_shape_partition.cc \ graph/partition/stage_partition.cc \ graph/passes/base_pass.cc \ diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 0275bc9f..77fa64fb 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -15,20 +15,40 @@ */ #include "graph/passes/mark_agnostic_pass.h" -#include "utils/node_utils.h" +#include "graph/utils/node_utils.h" namespace ge { Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto node_type = NodeUtils::GetNodeType(*node); if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { - GELOGD("Mark format agnostic for switch ndoe %s", node->GetName().c_str()); + GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); + const OpDescPtr op_desc = node->GetOpDesc(); + const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); + if (op_tensor == nullptr) { + GELOGD("Op: %s, Index:0,has no input", node->GetName().c_str()); + continue; + } + AttrUtils::SetInt(op_tensor, "_format_continuous", 1); + AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); + AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); + continue; + } + if (node_type == IDENTITY) { + GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); continue; } if (node_type == MERGE || node_type == REFMERGE) { - GELOGD("Mark format agnostic for merge node %s", node->GetName().c_str()); + GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); + const OpDescPtr op_desc = node->GetOpDesc(); + const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); + if (op_tensor == nullptr) { + GELOGD("Op: %s, Index:0,has no output", node->GetName().c_str()); + continue; + } + AttrUtils::SetInt(op_tensor, "_format_continuous", 1); AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector({1})); continue; @@ -36,4 +56,4 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { } return SUCCESS; } -} \ No newline at end of file +} // namespace ge \ No newline at end of file diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index f90c0d80..20964b6c 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -92,6 +92,7 @@ #include "graph/passes/unused_op_remove_pass.h" #include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" +#include "graph/passes/mark_agnostic_pass.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" @@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() { try { (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); + (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR;