Browse Source

!836 Custom pass register.

From: @zhao_zhixuan
Reviewed-by: @sheng-nan
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
6009e647a7
5 changed files with 19 additions and 2 deletions
  1. +15
    -0
      ge/graph/manager/graph_manager.cc
  2. +1
    -0
      ge/graph/manager/graph_manager.h
  3. +1
    -1
      metadef
  4. +1
    -1
      parser
  5. +1
    -0
      tests/ut/ge/CMakeLists.txt

+ 15
- 0
ge/graph/manager/graph_manager.cc View File

@@ -101,6 +101,7 @@
#include "graph/common/local_context.h"
#include "graph/common/omg_util.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "register/custom_pass_helper.h"

namespace {
const char *const kSummary = "Summary";
@@ -765,10 +766,24 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint
return SUCCESS;
}

Status GraphManager::RunCustomPass(const GraphNodePtr &graph_node) {
ConstGraphPtr const_graph = graph_node->GetGraph();
auto comp_graph = GraphUtils::GetComputeGraph(*const_graph);
GE_DUMP(comp_graph, "RunCustomPassBegin");

GE_TIMESTAMP_START(RunCustomPass);
GraphPtr graph = std::const_pointer_cast<Graph>(const_graph);
GE_CHK_STATUS_RET(CustomPassHelper::Instance().Run(graph), "Graph[%s] run custom pass fail.",
comp_graph->GetName().c_str());
GE_TIMESTAMP_END(RunCustomPass, "GraphBuilder::RunCustomPass");
return SUCCESS;
}

Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
GeRootModelPtr &ge_root_model, uint64_t session_id) {
GE_CHECK_NOTNULL(graph_node);
GE_CHECK_NOTNULL(graph_node->GetGraph());
GE_CHK_STATUS_RET_NOLOG(RunCustomPass(graph_node));
auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph());
GE_CHECK_NOTNULL(compute_graph);
compute_graph->SetSessionID(session_id);


+ 1
- 0
ge/graph/manager/graph_manager.h View File

@@ -226,6 +226,7 @@ class GraphManager {
void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor);
Status ParseInputsDimsForGetNexNosinkAndData(const vector<NodePtr> &dynamic_nodes,
const std::vector<InputTensorInfo> &input_tensor);
Status RunCustomPass(const GraphNodePtr &graph_node);
Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model,
uint64_t session_id = INVALID_SESSION_ID);



+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 11c6cf2921b6a385616a3ebc601b4431b55b07db
Subproject commit 44bcbb5ea25ada1a5393aa4c7f554d40b6859b18

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit 99437c39d26624a14060307366a96b79b1d439c3
Subproject commit 5b93b050dd7ca5b77c3001a790031d877fa10956

+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -121,6 +121,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc"
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp"
"${GE_CODE_DIR}/metadef/register/register.cpp"
"${GE_CODE_DIR}/metadef/register/register_pass.cpp"
"${GE_CODE_DIR}/metadef/register/op_kernel_registry.cpp"
"${GE_CODE_DIR}/metadef/register/auto_mapping_util.cpp"
"${GE_CODE_DIR}/metadef/register/tensor_assign.cpp"


Loading…
Cancel
Save