| @@ -44,19 +44,46 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const size_t kDataOutputNum = 1; | const size_t kDataOutputNum = 1; | ||||
| bool NeedHybridModel(GeModelPtr &ge_model) { | |||||
| Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
| GE_CHECK_NOTNULL(comp_graph); | |||||
| for (const auto &node : comp_graph->GetAllNodes()) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| const auto &depends = op_desc->GetOpInferDepends(); | |||||
| if (!depends.empty()) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||||
| bool infer_depend_flag = false; | |||||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | auto tasks = ge_model->GetModelTaskDefPtr()->task(); | ||||
| int32_t kernel_task_num = 0; | int32_t kernel_task_num = 0; | ||||
| for (int i = 0; i < tasks.size(); ++i) { | for (int i = 0; i < tasks.size(); ++i) { | ||||
| auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type()); | auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type()); | ||||
| if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { | if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { | ||||
| kernel_task_num++; | |||||
| if (kernel_task_num > 1) { | |||||
| return true; | |||||
| const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() : | |||||
| tasks[i].kernel_with_handle().context(); | |||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||||
| if (kernel_type == ccKernelType::TE) { | |||||
| if (infer_depend_flag) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| kernel_task_num++; | |||||
| if (kernel_task_num > 1) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return false; | |||||
| return SUCCESS; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -503,7 +530,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & | |||||
| auto ge_model = model_helper_.GetGeModel(); | auto ge_model = model_helper_.GetGeModel(); | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| if (NeedHybridModel(ge_model)) { | |||||
| bool need_hybrid_model = false; | |||||
| GE_CHK_STATUS_RET(NeedHybridModel(ge_model, need_hybrid_model), "[Check][NeedHybridModel] failed."); | |||||
| if (need_hybrid_model) { | |||||
| GELOGD("Build single op HybridModel."); | GELOGD("Build single op HybridModel."); | ||||
| GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); | GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); | ||||
| auto root_model = model_helper_.GetGeRootModel(); | auto root_model = model_helper_.GetGeRootModel(); | ||||