From: @zhao_zhixuan Reviewed-by: @sheng-nan Signed-off-by:tags/v1.2.0
@@ -101,6 +101,7 @@ | |||||
#include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "register/custom_pass_helper.h" | |||||
namespace { | namespace { | ||||
const char *const kSummary = "Summary"; | const char *const kSummary = "Summary"; | ||||
@@ -765,10 +766,24 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::RunCustomPass(const GraphNodePtr &graph_node) { | |||||
ConstGraphPtr const_graph = graph_node->GetGraph(); | |||||
auto comp_graph = GraphUtils::GetComputeGraph(*const_graph); | |||||
GE_DUMP(comp_graph, "RunCustomPassBegin"); | |||||
GE_TIMESTAMP_START(RunCustomPass); | |||||
GraphPtr graph = std::const_pointer_cast<Graph>(const_graph); | |||||
GE_CHK_STATUS_RET(CustomPassHelper::Instance().Run(graph), "Graph[%s] run custom pass fail.", | |||||
comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(RunCustomPass, "GraphBuilder::RunCustomPass"); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | ||||
GeRootModelPtr &ge_root_model, uint64_t session_id) { | GeRootModelPtr &ge_root_model, uint64_t session_id) { | ||||
GE_CHECK_NOTNULL(graph_node); | GE_CHECK_NOTNULL(graph_node); | ||||
GE_CHECK_NOTNULL(graph_node->GetGraph()); | GE_CHECK_NOTNULL(graph_node->GetGraph()); | ||||
GE_CHK_STATUS_RET_NOLOG(RunCustomPass(graph_node)); | |||||
auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); | auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
compute_graph->SetSessionID(session_id); | compute_graph->SetSessionID(session_id); | ||||
@@ -226,6 +226,7 @@ class GraphManager { | |||||
void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor); | void ParseInputsDimsForData(const std::vector<InputTensorInfo> &input_tensor); | ||||
Status ParseInputsDimsForGetNexNosinkAndData(const vector<NodePtr> &dynamic_nodes, | Status ParseInputsDimsForGetNexNosinkAndData(const vector<NodePtr> &dynamic_nodes, | ||||
const std::vector<InputTensorInfo> &input_tensor); | const std::vector<InputTensorInfo> &input_tensor); | ||||
Status RunCustomPass(const GraphNodePtr &graph_node); | |||||
Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model, | Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model, | ||||
uint64_t session_id = INVALID_SESSION_ID); | uint64_t session_id = INVALID_SESSION_ID); | ||||
@@ -1 +1 @@ | |||||
Subproject commit 11c6cf2921b6a385616a3ebc601b4431b55b07db | |||||
Subproject commit 44bcbb5ea25ada1a5393aa4c7f554d40b6859b18 |
@@ -1 +1 @@ | |||||
Subproject commit 99437c39d26624a14060307366a96b79b1d439c3 | |||||
Subproject commit 5b93b050dd7ca5b77c3001a790031d877fa10956 |
@@ -121,6 +121,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | "${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | ||||
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | "${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | ||||
"${GE_CODE_DIR}/metadef/register/register.cpp" | "${GE_CODE_DIR}/metadef/register/register.cpp" | ||||
"${GE_CODE_DIR}/metadef/register/register_pass.cpp" | |||||
"${GE_CODE_DIR}/metadef/register/op_kernel_registry.cpp" | "${GE_CODE_DIR}/metadef/register/op_kernel_registry.cpp" | ||||
"${GE_CODE_DIR}/metadef/register/auto_mapping_util.cpp" | "${GE_CODE_DIR}/metadef/register/auto_mapping_util.cpp" | ||||
"${GE_CODE_DIR}/metadef/register/tensor_assign.cpp" | "${GE_CODE_DIR}/metadef/register/tensor_assign.cpp" | ||||