| @@ -88,6 +88,7 @@ set(TRAIN_SRC_LIST | |||
| "graph/load/new_model_manager/model_utils.cc" | |||
| "graph/load/new_model_manager/aipp_utils.cc" | |||
| "graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/event_record_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/event_wait_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" | |||
| @@ -621,6 +622,7 @@ set(INFER_SRC_LIST | |||
| "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "single_op/task/op_task.cc" | |||
| @@ -398,6 +398,7 @@ REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | |||
| REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||
| REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | |||
| REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | |||
| REGISTER_OPTYPE_DEFINE(MODELEXIT, "ModelExit"); | |||
| REGISTER_OPTYPE_DEFINE(SEND, "Send"); | |||
| REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | |||
| REGISTER_OPTYPE_DEFINE(ENDOFSEQUENCE, "EndOfSequence"); | |||
| @@ -58,6 +58,7 @@ set(SRC_LIST | |||
| "../graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
| "../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
| "../graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
| "../graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
| "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "../opskernel_manager/ops_kernel_builder_manager.cc" | |||
| @@ -48,6 +48,7 @@ local_ge_executor_src_files := \ | |||
| ../graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | |||
| ../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | |||
| ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
| ../graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||
| ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | |||
| ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | |||
| ../opskernel_manager/ops_kernel_builder_manager.cc \ | |||
| @@ -247,6 +247,7 @@ OME_HOST_SRC_FILES := \ | |||
| graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | |||
| graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | |||
| single_op/task/op_task.cc \ | |||
| @@ -61,6 +61,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/load/new_model_manager/model_utils.cc \ | |||
| graph/load/new_model_manager/aipp_utils.cc \ | |||
| graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/event_record_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/event_wait_task_info.cc \ | |||
| graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/load/new_model_manager/task_info/model_exit_task_info.h" | |||
| #include "common/properties_manager.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/load/new_model_manager/davinci_model.h" | |||
| namespace ge { | |||
| Status ModelExitTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||
| GELOGI("InitModelExitTaskInfo Init Start."); | |||
| if (davinci_model == nullptr) { | |||
| GELOGE(PARAM_INVALID, "davinci_model is null!"); | |||
| return PARAM_INVALID; | |||
| } | |||
| Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "SetStream fail, stream_id:%u", task_def.stream_id()); | |||
| return ret; | |||
| } | |||
| model_ = davinci_model->GetRtModelHandle(); | |||
| GELOGI("InitModelExitTaskInfo Init Success, model:%p, stream:%p", model_, stream_); | |||
| return SUCCESS; | |||
| } | |||
| Status ModelExitTaskInfo::Distribute() { | |||
| GELOGI("ModelExitTaskInfo Distribute Start."); | |||
| rtError_t rt_ret = rtModelExit(model_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtModelExit failed, ret: 0x%x", rt_ret); | |||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||
| } | |||
| GELOGI("ModelExitTaskInfo Distribute Success."); | |||
| return SUCCESS; | |||
| } | |||
| REGISTER_TASK_INFO(RT_MODEL_TASK_MODEL_EXIT, ModelExitTaskInfo); | |||
| } // namespace ge | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||
| #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||
| #include "graph/load/new_model_manager/task_info/task_info.h" | |||
| namespace ge { | |||
| class ModelExitTaskInfo : public TaskInfo { | |||
| public: | |||
| ModelExitTaskInfo() {} | |||
| ~ModelExitTaskInfo() override { model_ = nullptr; } | |||
| Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||
| Status Distribute() override; | |||
| private: | |||
| rtModel_t model_{nullptr}; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||
| @@ -84,6 +84,22 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||
| return graph_change ? SUCCESS : NOT_CHANGED; | |||
| } | |||
| bool FlowCtrlPass::CheckMultiDataSet(ComputeGraphPtr &compute_graph) { | |||
| int data_set_num = 0; | |||
| for (auto &node : compute_graph->GetDirectNode()) { | |||
| if (node == nullptr) { | |||
| continue; | |||
| } | |||
| string type; | |||
| bool is_found = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||
| if (is_found && type == "IteratorV2") { | |||
| data_set_num++; | |||
| } | |||
| } | |||
| GELOGI("The ComputeGraph contain %d dataSet.", data_set_num); | |||
| return (data_set_num > 1) ? true : false; | |||
| } | |||
| NodePtr FlowCtrlPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name, | |||
| const std::vector<GeTensorDesc> &input_list, | |||
| const std::vector<GeTensorDesc> &output_list) { | |||
| @@ -310,12 +326,12 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
| * loopCond | |||
| * | | |||
| * v | |||
| * switch --> Assign | |||
| * switch --> Assign --> active --> ModelExit | |||
| * ^ | |||
| * | | |||
| * loopReset | |||
| */ | |||
| // Insert Assign node | |||
| // Insert Assign node and ctrl edge | |||
| NodePtr assign_node = | |||
| InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node); | |||
| if (assign_node == nullptr || switch_node == nullptr) { | |||
| @@ -325,13 +341,49 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||
| GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed"); | |||
| // 3. Insert ctrl edges | |||
| graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor()); | |||
| if (add_ret != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Add switch_node to assign_node ctrl edge failed, add_ret=%u.", add_ret); | |||
| return FAILED; | |||
| } | |||
| if (CheckMultiDataSet(compute_graph)) { | |||
| GELOGI("Multi dataSae exist, model_exit node is need."); | |||
| // 2. Insert active node and add ctrl edge | |||
| string active_name = switch_node->GetName() + "_StreamExitActive"; | |||
| NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {}); | |||
| if (active_node == nullptr) { | |||
| GELOGE(FAILED, "Insert stream active node:%s for IterCtrlTrueStream failed.", active_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed"); | |||
| GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true), | |||
| DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED); | |||
| string model_exit_name = switch_node->GetName() + "_ModelExit"; | |||
| GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed"); | |||
| add_ret = GraphUtils::AddEdge(assign_node->GetOutControlAnchor(), active_node->GetInControlAnchor()); | |||
| if (add_ret != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Add assign_node to active_node ctrl edge failed, add_ret=%u.", add_ret); | |||
| return FAILED; | |||
| } | |||
| // 3. Insert model exit node and add ctrl edge | |||
| NodePtr model_exit_node = InsertOp(compute_graph, MODELEXIT, model_exit_name, {}, {}); | |||
| if (model_exit_node == nullptr) { | |||
| GELOGE(FAILED, "Insert model_exit node:%s for IterCtrlTrueStream failed.", model_exit_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| GE_CHK_STATUS_RET(SetStreamLabel(model_exit_node, model_exit_name), "set stream label failed"); | |||
| add_ret = GraphUtils::AddEdge(active_node->GetOutControlAnchor(), model_exit_node->GetInControlAnchor()); | |||
| if (add_ret != GRAPH_SUCCESS) { | |||
| GELOGE(FAILED, "Add active_node to model_exit_node ctrl edge failed, add_ret=%u.", add_ret); | |||
| return FAILED; | |||
| } | |||
| } | |||
| GELOGI("CreateIterCtrlFalseBranch success."); | |||
| return SUCCESS; | |||
| } | |||
| @@ -134,6 +134,14 @@ class FlowCtrlPass : public GraphPass { | |||
| /// Other: failed | |||
| /// | |||
| Status AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &loop_after_node); | |||
| /// | |||
| /// add special iterator ctrl nodes(small cycle). | |||
| /// @param compute_graph graph | |||
| /// @return true: two or more dataSet exist | |||
| /// false: only one dataSet exist | |||
| /// | |||
| bool CheckMultiDataSet(ComputeGraphPtr &compute_graph); | |||
| }; | |||
| } // namespace ge | |||
| @@ -449,6 +449,7 @@ REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | |||
| REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||
| REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||
| REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | |||
| REGISTER_OPTYPE_DECLARE(MODELEXIT, "ModelExit"); | |||
| REGISTER_OPTYPE_DECLARE(SEND, "Send"); | |||
| REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||
| REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); | |||
| @@ -190,6 +190,7 @@ file(GLOB_RECURSE DISTINCT_GRAPH_LOAD_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "${GE_SOURCE_DIR}/src/ge/graph/load/output/output.cc" | |||
| @@ -1,18 +1,18 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| */ | |||
| #ifndef __CCE_RUNTIME_BASE_H__ | |||
| #define __CCE_RUNTIME_BASE_H__ | |||
| @@ -100,6 +100,9 @@ typedef enum tagRtError { | |||
| RT_ERROR_MODEL_ID, | |||
| RT_ERROR_MODEL_EXE_FAILED, | |||
| RT_ERROR_END_OF_SEQUENCE, // end of sequence | |||
| RT_ERROR_MODEL_EXIT, | |||
| RT_ERROR_MODEL_EXIT_STREAM_UNBIND, | |||
| RT_ERROR_MODEL_EXIT_ID, | |||
| RT_ERROR_EVENT_BASE = 0x07050000, | |||
| RT_ERROR_EVENT_NULL, | |||
| @@ -1,18 +1,18 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| */ | |||
| #ifndef __CCE_RUNTIME_MODEL_H__ | |||
| #define __CCE_RUNTIME_MODEL_H__ | |||
| @@ -49,6 +49,7 @@ typedef enum tagModelTaskType { | |||
| RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, | |||
| RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, | |||
| RT_MODEL_TASK_STREAM_LABEL_GOTO, | |||
| RT_MODEL_TASK_MODEL_EXIT, | |||
| } rtModelTaskType_t; | |||
| typedef enum tagModelStreamType { | |||
| @@ -224,6 +225,13 @@ typedef struct tagrtModelEndGraphTaskInfo { | |||
| uint32_t reserved[8]; | |||
| } rtModelEndGraphTaskInfo_t; | |||
| typedef struct tagrtModelExitInfo { | |||
| uint32_t modelId; | |||
| uint32_t streamId; | |||
| uint32_t reserved[8]; | |||
| } rtModelExitTaskInfo_t; | |||
| typedef struct tagrtStreamLabelSwitchByIndexTask_t { | |||
| uint64_t indexPtr; | |||
| uint64_t labelInfoPtr; | |||
| @@ -256,6 +264,7 @@ typedef struct tagTaskInfo { | |||
| rtRdmaSendTaskInfo_t rdmaSendTask; | |||
| rtRdmaDbSendTaskInfo_t rdmaDbSendTask; | |||
| rtModelEndGraphTaskInfo_t modelEndGraphTask; | |||
| rtModelExitTaskInfo_t modelExitTask; | |||
| rtStreamSwitchNTaskInfo_t streamSwitchNTask; | |||
| rtStreamLabelSwitchByIndexTask_t streamLabelSwitchIndexTask; | |||
| rtStreamLabelGotoTask_t streamLabelGotoTask; | |||
| @@ -389,6 +398,16 @@ RTS_API rtError_t rtModelExecutorSet(rtModel_t model, uint8_t flags); | |||
| */ | |||
| RTS_API rtError_t rtModelAbort(rtModel_t model); | |||
| /** | |||
| * @ingroup rt_model | |||
| * @brief end graph task to model default stream | |||
| * @param [in] model model to execute | |||
| * @param [in] end graph stream | |||
| * @return RT_ERROR_NONE for ok | |||
| * @return RT_ERROR_INVALID_VALUE for error input | |||
| */ | |||
| RTS_API rtError_t rtModelExit(rtModel_t model, rtStream_t stream); | |||
| /** | |||
| * @ingroup rt_model | |||
| * @brief bind queue | |||