Description:Support model_exit in GE Team:HISI_SW Feature or Bugfix:Featuretags/v1.1.0
@@ -88,6 +88,7 @@ set(TRAIN_SRC_LIST | |||
"graph/load/new_model_manager/model_utils.cc" | |||
"graph/load/new_model_manager/aipp_utils.cc" | |||
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
"graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
"graph/load/new_model_manager/task_info/event_record_task_info.cc" | |||
"graph/load/new_model_manager/task_info/event_wait_task_info.cc" | |||
"graph/load/new_model_manager/task_info/fusion_start_task_info.cc" | |||
@@ -621,6 +622,7 @@ set(INFER_SRC_LIST | |||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
"graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"single_op/task/op_task.cc" | |||
@@ -398,6 +398,7 @@ REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | |||
REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||
REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | |||
REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | |||
REGISTER_OPTYPE_DEFINE(MODELEXIT, "ModelExit"); | |||
REGISTER_OPTYPE_DEFINE(SEND, "Send"); | |||
REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | |||
REGISTER_OPTYPE_DEFINE(ENDOFSEQUENCE, "EndOfSequence"); | |||
@@ -58,6 +58,7 @@ set(SRC_LIST | |||
"../graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
"../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
"../graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
"../graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
"../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"../opskernel_manager/ops_kernel_builder_manager.cc" | |||
@@ -48,6 +48,7 @@ local_ge_executor_src_files := \ | |||
../graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | |||
../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | |||
../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
../graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||
../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | |||
../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | |||
../opskernel_manager/ops_kernel_builder_manager.cc \ | |||
@@ -247,6 +247,7 @@ OME_HOST_SRC_FILES := \ | |||
graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | |||
graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | |||
graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||
graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | |||
graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | |||
single_op/task/op_task.cc \ | |||
@@ -61,6 +61,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
graph/load/new_model_manager/model_utils.cc \ | |||
graph/load/new_model_manager/aipp_utils.cc \ | |||
graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||
graph/load/new_model_manager/task_info/event_record_task_info.cc \ | |||
graph/load/new_model_manager/task_info/event_wait_task_info.cc \ | |||
graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ | |||
@@ -0,0 +1,54 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/load/new_model_manager/task_info/model_exit_task_info.h" | |||
#include "common/properties_manager.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/load/new_model_manager/davinci_model.h" | |||
namespace ge { | |||
Status ModelExitTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||
GELOGI("InitModelExitTaskInfo Init Start."); | |||
if (davinci_model == nullptr) { | |||
GELOGE(PARAM_INVALID, "davinci_model is null!"); | |||
return PARAM_INVALID; | |||
} | |||
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "SetStream fail, stream_id:%u", task_def.stream_id()); | |||
return ret; | |||
} | |||
model_ = davinci_model->GetRtModelHandle(); | |||
GELOGI("InitModelExitTaskInfo Init Success, model:%p, stream:%p", model_, stream_); | |||
return SUCCESS; | |||
} | |||
Status ModelExitTaskInfo::Distribute() { | |||
GELOGI("ModelExitTaskInfo Distribute Start."); | |||
rtError_t rt_ret = rtModelExit(model_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rtModelExit failed, ret: 0x%x", rt_ret); | |||
return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
} | |||
GELOGI("ModelExitTaskInfo Distribute Success."); | |||
return SUCCESS; | |||
} | |||
REGISTER_TASK_INFO(RT_MODEL_TASK_MODEL_EXIT, ModelExitTaskInfo); | |||
} // namespace ge |
@@ -0,0 +1,37 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||
#include "graph/load/new_model_manager/task_info/task_info.h" | |||
namespace ge { | |||
class ModelExitTaskInfo : public TaskInfo { | |||
public: | |||
ModelExitTaskInfo() {} | |||
~ModelExitTaskInfo() override { model_ = nullptr; } | |||
Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
Status Distribute() override; | |||
private: | |||
rtModel_t model_{nullptr}; | |||
}; | |||
} // namespace ge | |||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ |
@@ -84,6 +84,22 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||
return graph_change ? SUCCESS : NOT_CHANGED; | |||
} | |||
bool FlowCtrlPass::CheckMultiDataSet(ComputeGraphPtr &compute_graph) { | |||
int data_set_num = 0; | |||
for (auto &node : compute_graph->GetDirectNode()) { | |||
if (node == nullptr) { | |||
continue; | |||
} | |||
string type; | |||
bool is_found = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||
if (is_found && type == "IteratorV2") { | |||
data_set_num++; | |||
} | |||
} | |||
GELOGI("The ComputeGraph contain %d dataSet.", data_set_num); | |||
return (data_set_num > 1) ? true : false; | |||
} | |||
NodePtr FlowCtrlPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name, | |||
const std::vector<GeTensorDesc> &input_list, | |||
const std::vector<GeTensorDesc> &output_list) { | |||
@@ -310,12 +326,12 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
* loopCond | |||
* | | |||
* v | |||
* switch --> Assign | |||
* switch --> Assign --> ModelExit | |||
* ^ | |||
* | | |||
* loopReset | |||
*/ | |||
// Insert Assign node | |||
// Insert Assign node and ctrl edge | |||
NodePtr assign_node = | |||
InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node); | |||
if (assign_node == nullptr || switch_node == nullptr) { | |||
@@ -325,13 +341,31 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed"); | |||
// 3. Insert ctrl edges | |||
graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor()); | |||
if (add_ret != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "Add switch_node to assign_node ctrl edge failed, add_ret=%u.", add_ret); | |||
return FAILED; | |||
} | |||
// 2. Insert model exit node and add ctrl edge | |||
if (CheckMultiDataSet(compute_graph)) { | |||
GELOGI("Multi dataSae exist, model_exit node is need."); | |||
string model_exit_name = switch_node->GetName() + "_ModelExit"; | |||
NodePtr model_exit_node = InsertOp(compute_graph, MODELEXIT, model_exit_name, {}, {}); | |||
if (model_exit_node == nullptr) { | |||
GELOGE(FAILED, "Insert model_exit node:%s for IterCtrlTrueStream failed.", model_exit_name.c_str()); | |||
return FAILED; | |||
} | |||
// Must set same stream label with assign_node | |||
GE_CHK_STATUS_RET(SetStreamLabel(model_exit_node, switch_node->GetName()), "set stream label failed"); | |||
add_ret = GraphUtils::AddEdge(assign_node->GetOutControlAnchor(), model_exit_node->GetInControlAnchor()); | |||
if (add_ret != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "Add assign_node to model_exit_node ctrl edge failed, add_ret=%u.", add_ret); | |||
return FAILED; | |||
} | |||
} | |||
GELOGI("CreateIterCtrlFalseBranch success."); | |||
return SUCCESS; | |||
} | |||
@@ -134,6 +134,14 @@ class FlowCtrlPass : public GraphPass { | |||
/// Other: failed | |||
/// | |||
Status AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &loop_after_node); | |||
/// | |||
/// add special iterator ctrl nodes(small cycle). | |||
/// @param compute_graph graph | |||
/// @return true: two or more dataSet exist | |||
/// false: only one dataSet exist | |||
/// | |||
bool CheckMultiDataSet(ComputeGraphPtr &compute_graph); | |||
}; | |||
} // namespace ge | |||
@@ -449,6 +449,7 @@ REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | |||
REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||
REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | |||
REGISTER_OPTYPE_DECLARE(MODELEXIT, "ModelExit"); | |||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | |||
REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||
REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); | |||
@@ -190,6 +190,7 @@ file(GLOB_RECURSE DISTINCT_GRAPH_LOAD_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR} | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"${GE_SOURCE_DIR}/src/ge/graph/load/output/output.cc" | |||